nano_gpt.config

Configuration module.

  1"""Configuration module."""
  2
  3import dataclasses
  4from dataclasses import dataclass
  5import enum
  6import logging
  7import pathlib
  8from typing import Protocol, Any
  9
 10import datasets
 11
 12__all__ = [
 13    "GPTConfig",
 14    "DatasetConfig",
 15    "TrainConfig",
 16    "SampleConfig",
 17    "EvalConfig",
 18    "TrainedModelConfig",
 19    "Models",
 20    "config_from",
 21    "model_config_from_pretrained",
 22    "TrainDataset",
 23    "LoadDataset",
 24]
 25
 26_LOGGER = logging.getLogger(__name__)
 27
 28VOCAB_SIZE = 50257  # Fixed size for GPT model checkpoints
 29NICE_VOCAB_SIZE = 50304  # Vocab size with nice power of 2, for training
 30BLOCK_SIZE = 1024  # Fixed size for GPT model checkpoints
 31DEFAULT_MICRO_BATCH_SIZE = 16
 32
 33
 34@dataclass(frozen=True, kw_only=True)
 35class GPTConfig:
 36    """This class defines the configuration for the GPT model.
 37
 38    This configuration is used for inference.
 39    """
 40
 41    block_size: int = BLOCK_SIZE
 42    """The maximum context length."""
 43
 44    vocab_size: int = NICE_VOCAB_SIZE
 45    """The size of the vocabulary."""
 46
 47    n_layer: int = 12
 48    """The number of transformer blocks."""
 49
 50    n_head: int = 12
 51    """The number of attention heads."""
 52
 53    n_embd: int = 768
 54    """The size of the embedding vector."""
 55
 56
 57@dataclass(frozen=True, kw_only=True)
 58class DatasetConfig:
 59    """This class defines the configuration for chunking the dataset."""
 60
 61    dataset_dir: str = "dataset_cache"
 62    """The directory where the dataset is stored."""
 63
 64    dataset_name: str = "tinyshakespeare"
 65    """The name of the dataset."""
 66
 67    micro_batch_size: int = DEFAULT_MICRO_BATCH_SIZE
 68    """Batch size (micro batch) (B) used for each forward/backward pass."""
 69
 70    sequence_length: int = BLOCK_SIZE
 71    """Sequence length (T) used for input content. Same as block_size."""
 72
 73    @property
 74    def chunk_token_size(self) -> int:
 75        """Number of tokens in each micro batch."""
 76        return self.micro_batch_size * self.sequence_length
 77
 78    def dataset_path(self, split: str) -> pathlib.Path:
 79        """Return the path to the dataset."""
 80        dataset_dir = pathlib.Path(self.dataset_dir)
 81        return dataset_dir / f"{self.dataset_name}_{split}.npy"
 82
 83
 84@dataclass(frozen=True, kw_only=True)
 85class SampleConfig:
 86    """This class defines the configuration for sampling the dataset."""
 87
 88    num_return_sequences: int = 5
 89    """The number of sequences to generate."""
 90
 91    max_length: int = 30
 92    """The maximum length of the generated sequences."""
 93
 94    text: str = "Hello, I'm a language model,"
 95    """The text to use as a prompt for sampling."""
 96
 97    seed: int = 42
 98    """The seed to use for sampling."""
 99
100
101@dataclass(frozen=True, kw_only=True)
102class EvalConfig:
103    """This class defines the configuration for the validation loss and HellaSwag eval."""
104
105    validation_steps: int = 20
106    """Number of validation loss iterations to perform each eval round."""
107
108    hellaswag_samples: int | None = None
109    """The number of HellaSwag evaluation results to sample or None for the entire set."""
110
111
@dataclass(frozen=True, kw_only=True)
class TrainConfig:
    """Training hyperparameters, implementing the GPT-3 learning rate schedule."""

    seed: int = 1337
    """The seed to use for training."""

    step: int = 0
    """The starting step to use for training."""

    total_batch_size: int
    """Total batch size in number of tokens for each gradient update.

    If this is larger than B * T, then the batch size is divided into
    micro-batches of size B * T as part of gradient accumulation.
    """

    micro_batch_size: int = DEFAULT_MICRO_BATCH_SIZE
    """Batch size (micro batch) (B) used for each forward/backward pass."""

    sequence_length: int = BLOCK_SIZE
    """Sequence length (T) used for input content. Same as block_size."""

    max_lr: float = 6e-4
    """Maximum learning rate."""

    min_lr_ratio: float = 0.1
    """Minimum learning rate ratio in terms of the max learning rate."""

    # 2**19 tokens per step
    # 10e9 - 10 billion tokens / 2**19 = 19073 steps
    # Warmup over 375 million tokens, from the GPT-2 paper:
    # 375e6 / 2**19 = 715 steps.
    # The warmup is very mild and could be made more aggressive.

    warmup_steps: int = 715
    """Number of warmup steps before getting to the max learning rate."""

    max_steps: int = 19073
    """Total number of training steps to perform."""

    eval_steps: int = 250
    """Number of steps between each evaluation of validation loss."""

    checkpoint_steps: int = 5000
    """Number of steps between each checkpoint save."""

    checkpoint_dir: str | None = None
    """Path with a filename format string containing {step} format."""

    log_file: str | None = None
    """Path to the log file."""

    def __post_init__(self) -> None:
        """Validate that the total batch size divides evenly into micro batches."""
        if self.total_batch_size % self.chunk_token_size != 0:
            raise ValueError(
                "Total batch size must be divisible by B * T"
                f" but got {self.total_batch_size} % {self.chunk_token_size}"
            )

    @property
    def chunk_token_size(self) -> int:
        """Number of tokens in each micro batch (B * T)."""
        return self.micro_batch_size * self.sequence_length

    @property
    def min_lr(self) -> float:
        """Minimum learning rate (max_lr scaled by min_lr_ratio)."""
        return self.max_lr * self.min_lr_ratio

    def grad_accum_steps(self, world_size: int) -> int:
        """Number of gradient accumulation steps for the given world size."""
        return self.total_batch_size // (self.chunk_token_size * world_size)

    def log_info(self, world_size: int) -> None:
        """Log a summary of the batch-size configuration."""
        _LOGGER.info("Token batch size: %s", self.micro_batch_size)
        _LOGGER.info("Sequence length: %s", self.sequence_length)
        _LOGGER.info("Total token batch size: %s", self.total_batch_size)
        _LOGGER.info(
            "Gradient accumulation steps: %s", self.grad_accum_steps(world_size)
        )
195
196
@dataclass(frozen=True)
class TrainedModelConfig:
    """Bundles a model name with its model and training configurations."""

    model_name: str
    """The name of the model."""

    model_config: GPTConfig
    """The configuration for the model."""

    train_config: TrainConfig
    """The configuration for the training."""
209
210
class Models(enum.Enum):
    """Preset model configurations, keyed by GPT-2 model size."""

    GPT2_SMALL = TrainedModelConfig(
        "gpt2",  # 124M params
        GPTConfig(n_layer=12, n_head=12, n_embd=768, vocab_size=VOCAB_SIZE),
        TrainConfig(
            total_batch_size=2**19,  # ~0.5M, in number of tokens
            max_lr=6e-4,
        ),
    )
    GPT2_MEDIUM = TrainedModelConfig(
        "gpt2-medium",  # 350M params
        GPTConfig(n_layer=24, n_head=16, n_embd=1024, vocab_size=VOCAB_SIZE),
        TrainConfig(
            total_batch_size=2**19,  # ~0.5M, in number of tokens
            max_lr=3e-4,
        ),
    )
    GPT2_LARGE = TrainedModelConfig(
        "gpt2-large",  # 774M params
        GPTConfig(n_layer=36, n_head=20, n_embd=1280, vocab_size=VOCAB_SIZE),
        TrainConfig(
            total_batch_size=2**19,  # ~0.5M, in number of tokens
            max_lr=2.5e-4,
        ),
    )
    GPT2_XL = TrainedModelConfig(
        "gpt2-xl",  # 1558M params
        GPTConfig(n_layer=48, n_head=25, n_embd=1600, vocab_size=VOCAB_SIZE),
        TrainConfig(
            total_batch_size=2**20,  #  ~1M, in number of tokens
            max_lr=2e-4,
        ),
    )

    # These are model sizes that were made up for this project
    GPT2_XS = TrainedModelConfig(
        "gpt2-xs",  # 58M params
        GPTConfig(n_layer=10, n_head=10, n_embd=512, vocab_size=VOCAB_SIZE),
        TrainConfig(
            total_batch_size=2**18,  # ~0.25M, in number of tokens
            max_lr=3e-4,
        ),
    )
    GPT2_XXS = TrainedModelConfig(
        "gpt2-xxs",  # ~3M params
        GPTConfig(n_layer=4, n_head=4, n_embd=64, vocab_size=VOCAB_SIZE),
        TrainConfig(
            total_batch_size=2**16,  # ~0.065M, in number of tokens
            max_lr=3e-4,
        ),
    )
264
265
# Names of the published GPT-2 checkpoints that may be loaded as pretrained
# weights.
PRETRAINED = {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"}

# Lookup table from model name to its TrainedModelConfig preset.
MODELS = {preset.value.model_name: preset.value for preset in Models}
273
274
def config_from(
    model_type: str,
    seed: int | None = None,
    micro_batch_size: int | None = None,
    sequence_length: int | None = None,
    total_batch_size: int | None = None,
    max_steps: int | None = None,
    eval_steps: int | None = None,
    checkpoint_steps: int | None = None,
    checkpoint_dir: str | None = None,
    log_file: str | None = None,
) -> TrainedModelConfig:
    """Return the configuration for the model, applying any overrides.

    Raises ValueError when ``model_type`` is not a known model name.
    """
    config = MODELS.get(model_type)
    if config is None:
        raise ValueError(f"Unknown model type: {model_type}")
    # Collect only the overrides that were explicitly provided.
    overrides: dict[str, Any] = {
        "seed": seed,
        "micro_batch_size": micro_batch_size,
        "sequence_length": sequence_length,
        "total_batch_size": total_batch_size,
        "max_steps": max_steps,
        "eval_steps": eval_steps,
        "checkpoint_steps": checkpoint_steps,
        "checkpoint_dir": checkpoint_dir,
        "log_file": log_file,
    }
    train_config_updates = {
        key: value for key, value in overrides.items() if value is not None
    }
    # The model's context length must track the training sequence length.
    model_config_updates: dict[str, Any] = {}
    if sequence_length is not None:
        model_config_updates["block_size"] = sequence_length
    return TrainedModelConfig(
        model_name=config.model_name,
        model_config=dataclasses.replace(
            config.model_config, **model_config_updates
        ),
        train_config=dataclasses.replace(
            config.train_config, **train_config_updates
        ),
    )
322
323
def model_config_from_pretrained(model_type: str) -> GPTConfig:
    """Return the model configuration for a published pretrained checkpoint.

    Raises ValueError when ``model_type`` is not a pretrained checkpoint name.
    """
    if model_type not in PRETRAINED:
        raise ValueError(f"Unknown model type: {model_type}")
    return config_from(model_type).model_config
330
331
def model_config_from_dict(data: dict[str, Any]) -> GPTConfig:
    """Build a GPTConfig from a pretrained model configuration dict.

    The context length may appear under "n_ctx" (GPT-2 style configs) or
    "block_size". Raises ValueError when neither key yields a value.
    """
    block_size = data.get("n_ctx", data.get("block_size"))
    if not block_size:
        raise ValueError("Missing block size in model config")
    dims = {key: data[key] for key in ("vocab_size", "n_layer", "n_head", "n_embd")}
    return GPTConfig(block_size=block_size, **dims)
344
345
class LoadDataset(Protocol):
    """Callable protocol for loading a dataset split."""

    def __call__(self, split: str, streaming: bool = False) -> datasets.Dataset:
        """Load the given dataset split, optionally in streaming mode."""
351
352
@dataclass
class TrainDataset:
    """Describes a training dataset and how to load it."""

    name: str
    """The name of the dataset."""

    load_fn: LoadDataset
    """Callable used to load the dataset."""

    total_tokens: int
    """The total number of tokens in the dataset."""

    tokens_per_shard: int
    """The number of tokens stored in each shard."""
@dataclass(frozen=True, kw_only=True)
class GPTConfig:
35@dataclass(frozen=True, kw_only=True)
36class GPTConfig:
37    """This class defines the configuration for the GPT model.
38
39    This configuration is used for inference.
40    """
41
42    block_size: int = BLOCK_SIZE
43    """The maximum context length."""
44
45    vocab_size: int = NICE_VOCAB_SIZE
46    """The size of the vocabulary."""
47
48    n_layer: int = 12
49    """The number of transformer blocks."""
50
51    n_head: int = 12
52    """The number of attention heads."""
53
54    n_embd: int = 768
55    """The size of the embedding vector."""

This class defines the configuration for the GPT model.

This configuration is used for inference.

GPTConfig( *, block_size: int = 1024, vocab_size: int = 50304, n_layer: int = 12, n_head: int = 12, n_embd: int = 768)
block_size: int = 1024

The maximum context length.

vocab_size: int = 50304

The size of the vocabulary.

n_layer: int = 12

The number of transformer blocks.

n_head: int = 12

The number of attention heads.

n_embd: int = 768

The size of the embedding vector.

@dataclass(frozen=True, kw_only=True)
class DatasetConfig:
58@dataclass(frozen=True, kw_only=True)
59class DatasetConfig:
60    """This class defines the configuration for chunking the dataset."""
61
62    dataset_dir: str = "dataset_cache"
63    """The directory where the dataset is stored."""
64
65    dataset_name: str = "tinyshakespeare"
66    """The name of the dataset."""
67
68    micro_batch_size: int = DEFAULT_MICRO_BATCH_SIZE
69    """Batch size (micro batch) (B) used for each forward/backward pass."""
70
71    sequence_length: int = BLOCK_SIZE
72    """Sequence length (T) used for input content. Same as block_size."""
73
74    @property
75    def chunk_token_size(self) -> int:
76        """Number of tokens in each micro batch."""
77        return self.micro_batch_size * self.sequence_length
78
79    def dataset_path(self, split: str) -> pathlib.Path:
80        """Return the path to the dataset."""
81        dataset_dir = pathlib.Path(self.dataset_dir)
82        return dataset_dir / f"{self.dataset_name}_{split}.npy"

This class defines the configuration for chunking the dataset.

DatasetConfig( *, dataset_dir: str = 'dataset_cache', dataset_name: str = 'tinyshakespeare', micro_batch_size: int = 16, sequence_length: int = 1024)
dataset_dir: str = 'dataset_cache'

The directory where the dataset is stored.

dataset_name: str = 'tinyshakespeare'

The name of the dataset.

micro_batch_size: int = 16

Batch size (micro batch) (B) used for each forward/backward pass.

sequence_length: int = 1024

Sequence length (T) used for input content. Same as block_size.

chunk_token_size: int
74    @property
75    def chunk_token_size(self) -> int:
76        """Number of tokens in each micro batch."""
77        return self.micro_batch_size * self.sequence_length

Number of tokens in each micro batch.

def dataset_path(self, split: str) -> pathlib.Path:
79    def dataset_path(self, split: str) -> pathlib.Path:
80        """Return the path to the dataset."""
81        dataset_dir = pathlib.Path(self.dataset_dir)
82        return dataset_dir / f"{self.dataset_name}_{split}.npy"

Return the path to the dataset.

@dataclass(frozen=True, kw_only=True)
class TrainConfig:
113@dataclass(frozen=True, kw_only=True)
114class TrainConfig:
115    """Implements the GPT-3 learning rate."""
116
117    seed: int = 1337
118    """The seed to use for training."""
119
120    step: int = 0
121    """The starting step to use for training."""
122
123    total_batch_size: int
124    """Total batch size in number of tokens for each gradient update.
125
126    If this is larger than B * T, then the batch size is divided into
127    micro-batches of size B * T as part of gradient accumulation.
128    """
129
130    micro_batch_size: int = DEFAULT_MICRO_BATCH_SIZE
131    """Batch size (micro batch) (B) used for each forward/backward pass."""
132
133    sequence_length: int = BLOCK_SIZE
134    """Sequence length (T) used for input content. Same as block_size."""
135
136    max_lr: float = 6e-4
137    """Maximum learning rate."""
138
139    min_lr_ratio: float = 0.1
140    """Minimum learning rate ratio in terms of the max learning rate."""
141
142    # 2**19 tokens per step
143    # 10e9 - 10 billion tokens / 2**19 = 19073
144    # warmup over 375 million tokens from the GPT-2 paper.
145    # 375e6 / 2**19 = 715 steps
146    # The warmup is very mild and could be made more aggressive
147
148    warmup_steps: int = 715
149    """Number of warmup steps before getting to the max learning rate."""
150
151    max_steps: int = 19073
152    """Total number of training steps to perform."""
153
154    eval_steps: int = 250
155    """Number of steps between each evaluation of validation loss."""
156
157    checkpoint_steps: int = 5000
158    """Number of steps between each checkpoint save."""
159
160    checkpoint_dir: str | None = None
161    """Path with a filename format string containing {step} format."""
162
163    log_file: str | None = None
164    """Path to the log file."""
165
166    def __post_init__(self) -> None:
167        """Post init."""
168        if self.total_batch_size % self.chunk_token_size != 0:
169            raise ValueError(
170                "Total batch size must be divisible by B * T"
171                f" but got {self.total_batch_size} % {self.chunk_token_size}"
172            )
173
174    @property
175    def chunk_token_size(self) -> int:
176        """Number of tokens in each micro batch."""
177        return self.micro_batch_size * self.sequence_length
178
179    @property
180    def min_lr(self) -> float:
181        """Minimum learning rate."""
182        return self.max_lr * self.min_lr_ratio
183
184    def grad_accum_steps(self, world_size: int) -> int:
185        """Number of gradient accumulation steps."""
186        return self.total_batch_size // (self.chunk_token_size * world_size)
187
188    def log_info(self, world_size: int) -> None:
189        """String representation."""
190        _LOGGER.info("Token batch size: %s", self.micro_batch_size)
191        _LOGGER.info("Sequence length: %s", self.sequence_length)
192        _LOGGER.info("Total token batch size: %s", self.total_batch_size)
193        _LOGGER.info(
194            "Gradient accumulation steps: %s", self.grad_accum_steps(world_size)
195        )

Implements the GPT-3 learning rate.

TrainConfig( *, seed: int = 1337, step: int = 0, total_batch_size: int, micro_batch_size: int = 16, sequence_length: int = 1024, max_lr: float = 0.0006, min_lr_ratio: float = 0.1, warmup_steps: int = 715, max_steps: int = 19073, eval_steps: int = 250, checkpoint_steps: int = 5000, checkpoint_dir: str | None = None, log_file: str | None = None)
seed: int = 1337

The seed to use for training.

step: int = 0

The starting step to use for training.

total_batch_size: int

Total batch size in number of tokens for each gradient update.

If this is larger than B * T, then the batch size is divided into micro-batches of size B * T as part of gradient accumulation.

micro_batch_size: int = 16

Batch size (micro batch) (B) used for each forward/backward pass.

sequence_length: int = 1024

Sequence length (T) used for input content. Same as block_size.

max_lr: float = 0.0006

Maximum learning rate.

min_lr_ratio: float = 0.1

Minimum learning rate ratio in terms of the max learning rate.

warmup_steps: int = 715

Number of warmup steps before getting to the max learning rate.

max_steps: int = 19073

Total number of training steps to perform.

eval_steps: int = 250

Number of steps between each evaluation of validation loss.

checkpoint_steps: int = 5000

Number of steps between each checkpoint save.

checkpoint_dir: str | None = None

Path with a filename format string containing {step} format.

log_file: str | None = None

Path to the log file.

chunk_token_size: int
174    @property
175    def chunk_token_size(self) -> int:
176        """Number of tokens in each micro batch."""
177        return self.micro_batch_size * self.sequence_length

Number of tokens in each micro batch.

min_lr: float
179    @property
180    def min_lr(self) -> float:
181        """Minimum learning rate."""
182        return self.max_lr * self.min_lr_ratio

Minimum learning rate.

def grad_accum_steps(self, world_size: int) -> int:
184    def grad_accum_steps(self, world_size: int) -> int:
185        """Number of gradient accumulation steps."""
186        return self.total_batch_size // (self.chunk_token_size * world_size)

Number of gradient accumulation steps.

def log_info(self, world_size: int) -> None:
188    def log_info(self, world_size: int) -> None:
189        """String representation."""
190        _LOGGER.info("Token batch size: %s", self.micro_batch_size)
191        _LOGGER.info("Sequence length: %s", self.sequence_length)
192        _LOGGER.info("Total token batch size: %s", self.total_batch_size)
193        _LOGGER.info(
194            "Gradient accumulation steps: %s", self.grad_accum_steps(world_size)
195        )

String representation.

@dataclass(frozen=True, kw_only=True)
class SampleConfig:
85@dataclass(frozen=True, kw_only=True)
86class SampleConfig:
87    """This class defines the configuration for sampling the dataset."""
88
89    num_return_sequences: int = 5
90    """The number of sequences to generate."""
91
92    max_length: int = 30
93    """The maximum length of the generated sequences."""
94
95    text: str = "Hello, I'm a language model,"
96    """The text to use as a prompt for sampling."""
97
98    seed: int = 42
99    """The seed to use for sampling."""

This class defines the configuration for sampling the dataset.

SampleConfig( *, num_return_sequences: int = 5, max_length: int = 30, text: str = "Hello, I'm a language model,", seed: int = 42)
num_return_sequences: int = 5

The number of sequences to generate.

max_length: int = 30

The maximum length of the generated sequences.

text: str = "Hello, I'm a language model,"

The text to use as a prompt for sampling.

seed: int = 42

The seed to use for sampling.

@dataclass(frozen=True, kw_only=True)
class EvalConfig:
102@dataclass(frozen=True, kw_only=True)
103class EvalConfig:
104    """This class defines the configuration for the validation loss and HellaSwag eval."""
105
106    validation_steps: int = 20
107    """Number of validation loss iterations to perform each eval round."""
108
109    hellaswag_samples: int | None = None
110    """The number of HellaSwag evaluation results to sample or None for the entire set."""

This class defines the configuration for the validation loss and HellaSwag eval.

EvalConfig(*, validation_steps: int = 20, hellaswag_samples: int | None = None)
validation_steps: int = 20

Number of validation loss iterations to perform each eval round.

hellaswag_samples: int | None = None

The number of HellaSwag evaluation results to sample or None for the entire set.

@dataclass(frozen=True)
class TrainedModelConfig:
198@dataclass(frozen=True)
199class TrainedModelConfig:
200    """This class defines the configuration for the GPT model."""
201
202    model_name: str
203    """The name of the model."""
204
205    model_config: GPTConfig
206    """The configuration for the model."""
207
208    train_config: TrainConfig
209    """The configuration for the training."""

This class bundles a model name with its model and training configurations.

TrainedModelConfig( model_name: str, model_config: GPTConfig, train_config: TrainConfig)
model_name: str

The name of the model.

model_config: GPTConfig

The configuration for the model.

train_config: TrainConfig

The configuration for the training.

class Models(enum.Enum):
212class Models(enum.Enum):
213    """This class defines the configuration for the GPT model."""
214
215    GPT2_SMALL = TrainedModelConfig(
216        "gpt2",  # 124M params
217        GPTConfig(n_layer=12, n_head=12, n_embd=768, vocab_size=VOCAB_SIZE),
218        TrainConfig(
219            total_batch_size=2**19,  # ~0.5M, in number of tokens
220            max_lr=6e-4,
221        ),
222    )
223    GPT2_MEDIUM = TrainedModelConfig(
224        "gpt2-medium",  # 350M params
225        GPTConfig(n_layer=24, n_head=16, n_embd=1024, vocab_size=VOCAB_SIZE),
226        TrainConfig(
227            total_batch_size=2**19,  # ~0.5M, in number of tokens
228            max_lr=3e-4,
229        ),
230    )
231    GPT2_LARGE = TrainedModelConfig(
232        "gpt2-large",  # 774M params
233        GPTConfig(n_layer=36, n_head=20, n_embd=1280, vocab_size=VOCAB_SIZE),
234        TrainConfig(
235            total_batch_size=2**19,  # ~0.5M, in number of tokens
236            max_lr=2.5e-4,
237        ),
238    )
239    GPT2_XL = TrainedModelConfig(
240        "gpt2-xl",  # 1558M params
241        GPTConfig(n_layer=48, n_head=25, n_embd=1600, vocab_size=VOCAB_SIZE),
242        TrainConfig(
243            total_batch_size=2**20,  #  ~1M, in number of tokens
244            max_lr=2e-4,
245        ),
246    )
247
248    # These are model sizes that were made up for this project
249    GPT2_XS = TrainedModelConfig(
250        "gpt2-xs",  # 58M params
251        GPTConfig(n_layer=10, n_head=10, n_embd=512, vocab_size=VOCAB_SIZE),
252        TrainConfig(
253            total_batch_size=2**18,  # ~0.25M, in number of tokens
254            max_lr=3e-4,
255        ),
256    )
257    GPT2_XXS = TrainedModelConfig(
258        "gpt2-xxs",  # ~3M params
259        GPTConfig(n_layer=4, n_head=4, n_embd=64, vocab_size=VOCAB_SIZE),
260        TrainConfig(
261            total_batch_size=2**16,  # ~0.065M, in number of tokens
262            max_lr=3e-4,
263        ),
264    )

This enum defines the preset model configurations, keyed by GPT-2 model size.

GPT2_SMALL = <Models.GPT2_SMALL: TrainedModelConfig(model_name='gpt2', model_config=GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768), train_config=TrainConfig(seed=1337, step=0, total_batch_size=524288, micro_batch_size=16, sequence_length=1024, max_lr=0.0006, min_lr_ratio=0.1, warmup_steps=715, max_steps=19073, eval_steps=250, checkpoint_steps=5000, checkpoint_dir=None, log_file=None))>
GPT2_MEDIUM = <Models.GPT2_MEDIUM: TrainedModelConfig(model_name='gpt2-medium', model_config=GPTConfig(block_size=1024, vocab_size=50257, n_layer=24, n_head=16, n_embd=1024), train_config=TrainConfig(seed=1337, step=0, total_batch_size=524288, micro_batch_size=16, sequence_length=1024, max_lr=0.0003, min_lr_ratio=0.1, warmup_steps=715, max_steps=19073, eval_steps=250, checkpoint_steps=5000, checkpoint_dir=None, log_file=None))>
GPT2_LARGE = <Models.GPT2_LARGE: TrainedModelConfig(model_name='gpt2-large', model_config=GPTConfig(block_size=1024, vocab_size=50257, n_layer=36, n_head=20, n_embd=1280), train_config=TrainConfig(seed=1337, step=0, total_batch_size=524288, micro_batch_size=16, sequence_length=1024, max_lr=0.00025, min_lr_ratio=0.1, warmup_steps=715, max_steps=19073, eval_steps=250, checkpoint_steps=5000, checkpoint_dir=None, log_file=None))>
GPT2_XL = <Models.GPT2_XL: TrainedModelConfig(model_name='gpt2-xl', model_config=GPTConfig(block_size=1024, vocab_size=50257, n_layer=48, n_head=25, n_embd=1600), train_config=TrainConfig(seed=1337, step=0, total_batch_size=1048576, micro_batch_size=16, sequence_length=1024, max_lr=0.0002, min_lr_ratio=0.1, warmup_steps=715, max_steps=19073, eval_steps=250, checkpoint_steps=5000, checkpoint_dir=None, log_file=None))>
GPT2_XS = <Models.GPT2_XS: TrainedModelConfig(model_name='gpt2-xs', model_config=GPTConfig(block_size=1024, vocab_size=50257, n_layer=10, n_head=10, n_embd=512), train_config=TrainConfig(seed=1337, step=0, total_batch_size=262144, micro_batch_size=16, sequence_length=1024, max_lr=0.0003, min_lr_ratio=0.1, warmup_steps=715, max_steps=19073, eval_steps=250, checkpoint_steps=5000, checkpoint_dir=None, log_file=None))>
GPT2_XXS = <Models.GPT2_XXS: TrainedModelConfig(model_name='gpt2-xxs', model_config=GPTConfig(block_size=1024, vocab_size=50257, n_layer=4, n_head=4, n_embd=64), train_config=TrainConfig(seed=1337, step=0, total_batch_size=65536, micro_batch_size=16, sequence_length=1024, max_lr=0.0003, min_lr_ratio=0.1, warmup_steps=715, max_steps=19073, eval_steps=250, checkpoint_steps=5000, checkpoint_dir=None, log_file=None))>
def config_from( model_type: str, seed: int | None = None, micro_batch_size: int | None = None, sequence_length: int | None = None, total_batch_size: int | None = None, max_steps: int | None = None, eval_steps: int | None = None, checkpoint_steps: int | None = None, checkpoint_dir: str | None = None, log_file: str | None = None) -> TrainedModelConfig:
276def config_from(
277    model_type: str,
278    seed: int | None = None,
279    micro_batch_size: int | None = None,
280    sequence_length: int | None = None,
281    total_batch_size: int | None = None,
282    max_steps: int | None = None,
283    eval_steps: int | None = None,
284    checkpoint_steps: int | None = None,
285    checkpoint_dir: str | None = None,
286    log_file: str | None = None,
287) -> TrainedModelConfig:
288    """Return the configuration for the model."""
289    if (config := MODELS.get(model_type)) is None:
290        raise ValueError(f"Unknown model type: {model_type}")
291    model_config_updates = {}
292    train_config_updates: dict[str, Any] = {}
293    if seed is not None:
294        train_config_updates["seed"] = seed
295    if micro_batch_size is not None:
296        train_config_updates["micro_batch_size"] = micro_batch_size
297    if sequence_length is not None:
298        train_config_updates["sequence_length"] = sequence_length
299        model_config_updates["block_size"] = sequence_length
300    if total_batch_size is not None:
301        train_config_updates["total_batch_size"] = total_batch_size
302    if max_steps is not None:
303        train_config_updates["max_steps"] = max_steps
304    if eval_steps is not None:
305        train_config_updates["eval_steps"] = eval_steps
306    if checkpoint_steps is not None:
307        train_config_updates["checkpoint_steps"] = checkpoint_steps
308    if checkpoint_dir is not None:
309        train_config_updates["checkpoint_dir"] = checkpoint_dir
310    if log_file is not None:
311        train_config_updates["log_file"] = log_file
312    return TrainedModelConfig(
313        model_name=config.model_name,
314        model_config=dataclasses.replace(
315            config.model_config,
316            **model_config_updates,
317        ),
318        train_config=dataclasses.replace(
319            config.train_config,
320            **train_config_updates,
321        ),
322    )

Return the configuration for the model.

def model_config_from_pretrained(model_type: str) -> GPTConfig:
325def model_config_from_pretrained(model_type: str) -> GPTConfig:
326    """Return the configuration for the pretrained model."""
327    if model_type not in PRETRAINED:
328        raise ValueError(f"Unknown model type: {model_type}")
329    config = config_from(model_type)
330    return config.model_config

Return the configuration for the pretrained model.

@dataclass
class TrainDataset:
354@dataclass
355class TrainDataset:
356    """A dataset."""
357
358    name: str
359    """The name of the dataset."""
360
361    load_fn: LoadDataset
362    """The function to load the dataset."""
363
364    total_tokens: int
365    """The total number of tokens in the dataset."""
366
367    tokens_per_shard: int
368    """The number of tokens per shard."""

A dataset.

TrainDataset( name: str, load_fn: LoadDataset, total_tokens: int, tokens_per_shard: int)
name: str

The name of the dataset.

load_fn: LoadDataset

The function to load the dataset.

total_tokens: int

The total number of tokens in the dataset.

tokens_per_shard: int

The number of tokens per shard.

class LoadDataset(typing.Protocol):
347class LoadDataset(Protocol):
348    """A protocol for loading a dataset."""
349
350    def __call__(self, split: str, streaming: bool = False) -> datasets.Dataset:
351        """Load a dataset."""

A protocol for loading a dataset.

LoadDataset(*args, **kwargs)
1866def _no_init_or_replace_init(self, *args, **kwargs):
1867    cls = type(self)
1868
1869    if cls._is_protocol:
1870        raise TypeError('Protocols cannot be instantiated')
1871
1872    # Already using a custom `__init__`. No need to calculate correct
1873    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1874    if cls.__init__ is not _no_init_or_replace_init:
1875        return
1876
1877    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1878    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1879    # searches for a proper new `__init__` in the MRO. The new `__init__`
1880    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1881    # instantiation of the protocol subclass will thus use the new
1882    # `__init__` and no longer call `_no_init_or_replace_init`.
1883    for base in cls.__mro__:
1884        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1885        if init is not _no_init_or_replace_init:
1886            cls.__init__ = init
1887            break
1888    else:
1889        # should not happen
1890        cls.__init__ = object.__init__
1891
1892    cls.__init__(self, *args, **kwargs)