# nano_gpt.datasets.hellaswag
# Data loader library for the hellaswag dataset.
"""Data loader library for the hellaswag dataset."""

from dataclasses import dataclass
from collections.abc import Iterable
import logging
from typing import Any

import datasets
import torch

from .data_loader import MapIterable
from nano_gpt.tokenizer import Tokenizer

_LOGGER = logging.getLogger(__name__)

__all__ = [
    "load_dataset",
]


NUM_ANSWERS = 4


@dataclass(frozen=True, kw_only=True)
class Sample:
    """A sample multiple choice question."""

    prefix: str
    """The prefix of the question."""

    endings: list[str]
    """List of possible endings."""

    label: int
    """Index of the correct ending."""

    @property
    def ending_texts(self) -> list[str]:
        """Return each ending prefixed with a separating space."""
        return [f" {ending}" for ending in self.endings]

    @property
    def completions(self) -> list[str]:
        """Return the completion candidates (prefix joined to each ending)."""
        return [f"{self.prefix} {ending}" for ending in self.endings]

    @property
    def max_len(self) -> int:
        """Return the maximum character length over all completions."""
        return max(len(row) for row in self.completions)

    def tokenize(self, tokenizer: Tokenizer) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a sample and return the tokens and mask.

        Returns a (NUM_ANSWERS, max_len) long tensor of token ids and a
        parallel mask tensor that is 1 over ending tokens and 0 over the
        prefix and padding.
        """
        # max_len counts characters, not tokens; this assumes the tokenizer
        # never produces more tokens than characters — TODO confirm for the
        # tokenizer in use (holds for typical BPE tokenizers).
        max_len = self.max_len
        prefix_toks = tokenizer.encode(self.prefix)
        prefix_masks = [0] * len(prefix_toks)

        tokens = torch.zeros((NUM_ANSWERS, max_len), dtype=torch.long)
        mask = torch.zeros((NUM_ANSWERS, max_len), dtype=torch.long)
        for i, ending in enumerate(self.ending_texts):
            ending_toks = tokenizer.encode(ending)
            ending_mask = [1] * len(ending_toks)
            tok_row = prefix_toks + ending_toks
            mask_row = prefix_masks + ending_mask
            # Rows are left-aligned; trailing positions stay zero-padded.
            tokens[i, : len(tok_row)] = torch.tensor(tok_row)
            mask[i, : len(mask_row)] = torch.tensor(mask_row)

        return tokens, mask


def _make_sample(record: dict[str, Any]) -> Sample:
    """Convert a raw hellaswag record into a Sample."""
    return Sample(
        prefix=record["ctx"],
        endings=record["endings"],
        label=int(record["label"]),
    )


def load_dataset(split: str) -> Iterable[Sample]:
    """Load the hellaswag dataset for the given split."""
    ds = datasets.load_dataset("Rowan/hellaswag", split=split)
    return MapIterable(_make_sample, ds)
def load_dataset(split: str) -> Iterable[Sample]:
    """Load the hellaswag dataset for the given split.

    Wraps the HuggingFace "Rowan/hellaswag" dataset, mapping each raw
    record to a Sample lazily via MapIterable.
    """
    ds = datasets.load_dataset("Rowan/hellaswag", split=split)
    return MapIterable(_make_sample, ds)