# nano_gpt.datasets.hellaswag
#
# Data loader library for the hellaswag dataset.
 1"""Data loader library for the hellaswag dataset."""
 2
 3from dataclasses import dataclass
 4from collections.abc import Iterable
 5import logging
 6from typing import Any
 7
 8import datasets
 9import torch
10
11from .data_loader import MapIterable
12from nano_gpt.tokenizer import Tokenizer
13
14_LOGGER = logging.getLogger(__name__)
15
16__all__ = [
17    "load_dataset",
18]
19
20
# Number of candidate endings per hellaswag question; the published dataset
# ships exactly four endings per record.
NUM_ANSWERS = 4
22
23
24@dataclass(frozen=True, kw_only=True)
25class Sample:
26    """A sample multiple choice question."""
27
28    prefix: str
29    """The prefix of the question."""
30
31    endings: list[str]
32    """List of possible endings."""
33
34    label: int
35    """Index of the correct ending."""
36
37    @property
38    def ending_texts(self) -> list[str]:
39        """Return the completion candidates."""
40        return [f" {ending}" for ending in self.endings]
41
42    @property
43    def completions(self) -> list[str]:
44        """Return the completion candidates."""
45        return [f"{self.prefix} {ending}" for ending in self.endings]
46
47    @property
48    def max_len(self) -> int:
49        """Return the maximum length of the sample."""
50        return max(len(row) for row in self.completions)
51
52    def tokenize(self, tokenizer: Tokenizer) -> tuple[torch.Tensor, torch.Tensor]:
53        """Tokenize a sample and return the tokens and mask."""
54
55        max_len = self.max_len
56        prefix_toks = tokenizer.encode(self.prefix)
57        prefix_masks = [0] * len(prefix_toks)
58
59        tokens = torch.zeros((4, max_len), dtype=torch.long)
60        mask = torch.zeros((4, max_len), dtype=torch.long)
61        for i, ending in enumerate(self.ending_texts):
62            ending_toks = tokenizer.encode(ending)
63            ending_mask = [1] * len(ending_toks)
64            tok_row = prefix_toks + ending_toks
65            mask_row = prefix_masks + ending_mask
66            tokens[i, : len(tok_row)] = torch.tensor(tok_row)
67            mask[i, : len(mask_row)] = torch.tensor(mask_row)
68
69        return tokens, mask
70
71
def _make_sample(record: dict[str, Any]) -> Sample:
    """Convert one raw hellaswag record into a ``Sample``."""
    ctx = record["ctx"]
    endings = record["endings"]
    label = int(record["label"])
    return Sample(prefix=ctx, endings=endings, label=label)
76
77
def load_dataset(split: str) -> Iterable[Sample]:
    """Return an iterable of ``Sample``s for *split* of Rowan/hellaswag."""
    return MapIterable(
        _make_sample,
        datasets.load_dataset("Rowan/hellaswag", split=split),
    )