nano_gpt.tool.model_config

Shared library for command line flags for loading models.

This module provides functions for creating and parsing command line arguments for loading models, as well as functions for converting these arguments into model configurations.

This can be used to load a model from a checkpoint or a pretrained model, or to initialize a model from a pre-defined model configuration from the GPT-2 paper.
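A minimal sketch of how a command line tool might wire these helpers together; the parser description and flag values below are illustrative only:

    from argparse import ArgumentParser

    from nano_gpt.tool.model_config import (
        create_model_arguments,
        create_eval_arguments,
        create_sample_arguments,
        create_dataset_arguments,
    )

    parser = ArgumentParser(description="example tool")
    create_model_arguments(parser)    # --pretrained, --model, --checkpoint, --device, ...
    create_eval_arguments(parser)     # --validation-steps, --hellaswag-samples
    create_sample_arguments(parser)   # --sample-num-sequences, --sample-max-length, --sample-seed
    create_dataset_arguments(parser)  # --dataset, --dataset-dir, --micro-batch-size
    args = parser.parse_args(["--model", "gpt2", "--sequence-length", "1024"])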

  1"""Shared library for command line flags for loading models.
  2
  3This module provides functions for creating and parsing command line arguments for
  4loading models, as well as functions for converting these arguments into model
  5configurations.
  6
  7This can be used to load a model from a checkpoint, a pretrained model, or
  8initialize a model from pre-defined model configuration from the GPT-2 paper.
  9"""
 10
 11from argparse import ArgumentParser, BooleanOptionalAction
 12from collections.abc import Generator
 13from contextlib import contextmanager
 14import dataclasses
 15import json
 16import logging
 17import pathlib
 18from typing import Any, cast
 19
 20import torch
 21from huggingface_hub import HfFileSystem
 22
 23from nano_gpt.checkpoint import load_checkpoint, Checkpoint
 24from nano_gpt.config import (
 25    MODELS,
 26    config_from,
 27    TrainedModelConfig,
 28    EvalConfig,
 29    SampleConfig,
 30    DatasetConfig,
 31    model_config_from_pretrained,
 32    model_config_from_dict,
 33)
 34from nano_gpt.datasets import TRAIN_DATASETS
 35from nano_gpt.devices import get_device
 36from nano_gpt.model import GPT
 37from nano_gpt.tokenizer import get_tokenizer, Tokenizer
 38
 39_LOGGER = logging.getLogger(__name__)
 40
 41DATASET_DIR = "dataset_cache"
 42
 43
 44def create_model_arguments(
 45    args: ArgumentParser, default_values: dict[str, Any] | None = None
 46) -> None:
 47    """Create arguments for model loading."""
 48    if default_values is None:
 49        default_values = {}
 50    group = args.add_argument_group("model")
 51    group.add_argument(
 52        "--pretrained",
 53        type=str,
 54        help="The name of the pretrained model to use.",
 55    )
 56    group.add_argument(
 57        "--model",
 58        type=str,
 59        default=default_values.get("model", "gpt2"),
 60        choices=sorted(MODELS),
 61        help="Use the specified model name configuration default values.",
 62    )
 63    group.add_argument(
 64        "--checkpoint",
 65        type=str,
 66        help="Load a model from a checkpoint.",
 67    )
 68    group.add_argument(
 69        "--device",
 70        type=str,
 71        help="The device to use.",
 72    )
 73    group.add_argument(
 74        "--sequence-length",
 75        type=int,
 76        help="The sequence length used for input content in each micro batch.",
 77    )
 78    group.add_argument(
 79        "--seed",
 80        type=int,
 81        help="The seed to use for sampling/training.",
 82    )
 83    group.add_argument(
 84        "--compile",
 85        type=str,
 86        action=BooleanOptionalAction,
 87        default=True,
 88        help="Will compile the model if supported by the device.",
 89    )
 90
 91
 92def _check_model_arguments(args: Any) -> None:
 93    """Check that the model arguments are valid."""
 94    if args.pretrained is None and args.checkpoint is None and args.model is None:
 95        raise ValueError(
 96            "Either --pretrained or --checkpoint or --model must be specified"
 97        )
 98
 99
100def model_config_from_args(
101    args: Any,
102) -> TrainedModelConfig:
103    """Create a model from the flags."""
104    return config_from(
105        args.model,
106        **{
107            key: value
108            for key in {"micro_batch_size", "sequence_length", "total_batch_size"}
109            if (value := getattr(args, key, None)) is not None
110        },
111    )
112
113
114@contextmanager
115def load_checkpoint_context(args: Any) -> Generator[Checkpoint | None, None, None]:
116    """Load a checkpoint from the flags.
117
118    This is a context manager so that the checkpoint can be used across multiple calls to
119    parse arguments, but then discarded after the checkpoint is no longer needed.
120    """
121    if args.checkpoint is not None:
122        checkpoint_path = pathlib.Path(args.checkpoint)
123        _LOGGER.info("Restoring from checkpoint: %s", checkpoint_path)
124        yield load_checkpoint(checkpoint_path, args.device)
125    else:
126        yield None
127
128
129def _trained_model_config_dict_from_args(args: Any) -> dict[str, Any]:
130    """Create a TrainedModelConfig parameter dict from flags."""
131    return {
132        key: value
133        for key in {
134            "seed",
135            "micro_batch_size",
136            "sequence_length",
137            "total_batch_size",
138            "max_steps",
139            "eval_steps",
140            "eval_num_samples",
141            "checkpoint_steps",
142            "checkpoint_dir",
143            "log_file",
144        }
145        if (value := getattr(args, key, None)) is not None
146    }
147
148
149def model_from_args(
150    args: Any, checkpoint: Checkpoint | None
151) -> tuple[GPT, Tokenizer, TrainedModelConfig | None]:
152    """Create a model from the flags."""
153    _check_model_arguments(args)
154    tokenizer = get_tokenizer()
155    trained_model_config: TrainedModelConfig | None = None
156    if args.pretrained is not None:
157        if checkpoint is not None:
158            raise ValueError("Cannot specify both --pretrained and --checkpoint")
159        _LOGGER.info("loading weights from pretrained gpt: %s" % args.pretrained)
160        pretrained_args: dict[str, Any] = {}
161        if args.pretrained.startswith("./") or args.pretrained.startswith("/"):
162            # If the pretrained model is a local path, we need to load it from the local
163            local_path = pathlib.Path(args.pretrained)
164            model_config_path = local_path / "config.json"
165            _LOGGER.info("Loading model config from %s", model_config_path)
166            data = json.loads(model_config_path.read_text())
167            model_config = model_config_from_dict(data)
168        elif args.pretrained in MODELS:
169            _LOGGER.info("Loading known model config: %s", args.pretrained)
170            model_config = model_config_from_pretrained(args.pretrained)
171        else:
172            fs = HfFileSystem()
173            model_config_path = pathlib.Path(args.pretrained) / "/config.json"
174            _LOGGER.info("Loading model config from %s", model_config_path)
175            data = json.loads(fs.read_text(str(model_config_path)))
176            model_config = model_config_from_dict(data)
177        _LOGGER.info("Initializing model from pretrained config: %s", model_config)
178        model = GPT.from_pretrained(
179            args.pretrained,
180            tokenizer=tokenizer,
181            model_config=model_config,
182            **pretrained_args,
183        )
184    elif checkpoint is not None:
185        _LOGGER.debug("initializing model from checkpoint: %s", checkpoint.config)
186        model = GPT(checkpoint.config, tokenizer=tokenizer)
187        model.load_state_dict(checkpoint.model_state_dict_for_inference)
188        model_config = checkpoint.config
189        train_config = dataclasses.replace(
190            checkpoint.train_config,
191            **_trained_model_config_dict_from_args(args),
192        )
193        trained_model_config = TrainedModelConfig(
194            model_name=checkpoint.name or "checkpoint",
195            model_config=checkpoint.config,
196            train_config=train_config,
197        )
198    else:
199        trained_model_config = config_from(
200            args.model,
201            **_trained_model_config_dict_from_args(args),
202        )
203        model_config = trained_model_config.model_config
204        _LOGGER.debug("initializing model from config: %s", model_config)
205        model = GPT(model_config, tokenizer=tokenizer)
206    _LOGGER.info("Trained model config: %s", trained_model_config)
207    if args.device is None:
208        args.device = get_device()
209    # TODO: Fix compilation with DDP
210    if args.device == "cuda":
211        if args.compile:
212            _LOGGER.info("Compiling model")
213            try:
214                model = cast(GPT, torch.compile(model))
215            except RuntimeError as err:
216                raise RuntimeError(
217                    f"Failed to compile model, try with --no-compile: {err}"
218                ) from err
219        else:
220            _LOGGER.debug("Not compiling model")
221    else:
222        _LOGGER.debug("Model will not be compiled (%s)", args.device)
223
224    seed: int | None = None
225    if (
226        trained_model_config is not None
227        and trained_model_config.train_config.seed is not None
228    ):
229        seed = trained_model_config.train_config.seed
230    if args.seed is not None:
231        seed = args.seed
232
233    if seed is not None:
234        _LOGGER.info("Setting seed to %s", seed)
235        torch.manual_seed(seed)
236        torch.cuda.manual_seed(seed)
237
238    return model, tokenizer, trained_model_config
239
240
241def create_eval_arguments(args: ArgumentParser) -> None:
242    """Create arguments for model evaluation."""
243    group = args.add_argument_group("eval")
244    group.add_argument(
245        "--validation-steps",
246        type=int,
247        help="Number of validation loss iterations to perform each eval round.",
248    )
249    group.add_argument(
250        "--hellaswag-samples",
251        type=int,
252        help="The number of HellaSwag evaluation results to sample or None for the entire set.",
253    )
254
255
256def eval_config_from_args(args: Any, checkpoint: Checkpoint | None) -> EvalConfig:
257    """Create an EvalConfig from the flags."""
258    values = {}
259    if args.validation_steps is not None:
260        values["validation_steps"] = args.validation_steps
261    if args.hellaswag_samples is not None:
262        values["hellaswag_samples"] = args.hellaswag_samples
263    if checkpoint is not None and checkpoint.eval_config is not None:
264        return dataclasses.replace(
265            checkpoint.eval_config,
266            **values,
267        )
268    return EvalConfig(**values)
269
270
271def create_sample_arguments(args: ArgumentParser) -> None:
272    """Create arguments for model sampling."""
273    group = args.add_argument_group("sample")
274    group.add_argument(
275        "--sample-num-sequences",
276        type=int,
277        help="The number of sequences to generate.",
278    )
279    group.add_argument(
280        "--sample-max-length",
281        type=int,
282        help="The maximum length of the generated sequences.",
283    )
284    group.add_argument(
285        "--sample-seed",
286        type=int,
287        help="The seed to use for sampling.",
288    )
289
290
291def sample_config_from_args(args: Any, checkpoint: Checkpoint | None) -> SampleConfig:
292    """Create an SampleConfig from the flags."""
293    values = {}
294    if args.sample_num_sequences is not None:
295        values["num_return_sequences"] = args.sample_num_sequences
296    if args.sample_max_length is not None:
297        values["max_length"] = args.sample_max_length
298    if args.sample_seed is not None:
299        values["seed"] = args.sample_seed
300    if checkpoint is not None and checkpoint.sample_config is not None:
301        return dataclasses.replace(
302            checkpoint.sample_config,
303            **values,
304        )
305    return SampleConfig(**values)
306
307
308def create_dataset_arguments(args: ArgumentParser) -> None:
309    """Create arguments for dataset loading."""
310    group = args.add_argument_group("dataset")
311    group.add_argument(
312        "--dataset",
313        type=str,
314        help="Use the specified dataset.",
315        choices=TRAIN_DATASETS.keys(),
316        required=False,
317    )
318    group.add_argument(
319        "--dataset-dir",
320        type=str,
321        help="Directory where the dataset is stored.",
322        default=DATASET_DIR,
323    )
324    args.add_argument(
325        "--micro-batch-size",
326        type=int,
327        help="The number of batches of examples to pull from the dataset in each micro step.",
328    )
329
330
331def dataset_config_from_args(args: Any, checkpoint: Checkpoint | None) -> DatasetConfig:
332    """Create a DatasetConfig from the flags."""
333    values = {}
334    if args.dataset is not None:
335        values["dataset_name"] = args.dataset
336    if args.dataset_dir is not None:
337        values["dataset_dir"] = args.dataset_dir
338    if args.micro_batch_size is not None:
339        values["micro_batch_size"] = args.micro_batch_size
340    if args.sequence_length is not None:
341        values["sequence_length"] = args.sequence_length
342    if checkpoint is not None and checkpoint.dataset_config is not None:
343        return dataclasses.replace(
344            checkpoint.dataset_config,
345            **values,
346        )
347    return DatasetConfig(**values)
DATASET_DIR = 'dataset_cache'
def create_model_arguments(args: argparse.ArgumentParser, default_values: dict[str, Any] | None = None) -> None:

Create arguments for model loading.

def model_config_from_args(args: Any) -> nano_gpt.config.TrainedModelConfig:

Create a model config from the flags.
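A short sketch, assuming the module is importable as documented: parse only the model flags and build a TrainedModelConfig without instantiating a model. The flag values are illustrative.

    from argparse import ArgumentParser

    from nano_gpt.tool.model_config import create_model_arguments, model_config_from_args

    parser = ArgumentParser()
    create_model_arguments(parser)
    args = parser.parse_args(["--model", "gpt2", "--sequence-length", "512"])
    # Only flags that were actually provided are forwarded to config_from();
    # everything else falls back to the defaults for the named model.
    config = model_config_from_args(args)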

@contextmanager
def load_checkpoint_context(args: Any) -> Generator[nano_gpt.checkpoint.Checkpoint | None, None, None]:

Load a checkpoint from the flags.

This is a context manager so that the checkpoint can be used across multiple calls to parse arguments, and then discarded once it is no longer needed.
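A minimal sketch of that usage, assuming a checkpoint file exists at the hypothetical path below: the checkpoint is loaded once and shared by the config builders inside the with block.

    from argparse import ArgumentParser

    from nano_gpt.tool.model_config import (
        create_model_arguments,
        create_eval_arguments,
        load_checkpoint_context,
        model_from_args,
        eval_config_from_args,
    )

    parser = ArgumentParser()
    create_model_arguments(parser)
    create_eval_arguments(parser)
    args = parser.parse_args(["--checkpoint", "checkpoints/step_1000.pt", "--device", "cpu"])

    with load_checkpoint_context(args) as checkpoint:
        # Both calls reuse the same loaded checkpoint; it can be released afterwards.
        model, tokenizer, trained_config = model_from_args(args, checkpoint)
        eval_config = eval_config_from_args(args, checkpoint)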

def model_from_args(args: Any, checkpoint: nano_gpt.checkpoint.Checkpoint | None) -> tuple[nano_gpt.model.GPT, nano_gpt.tokenizer.Tokenizer, nano_gpt.config.TrainedModelConfig | None]:

Create a model from the flags.
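A sketch of the pretrained path, assuming "gpt2" is one of the known MODELS names; --pretrained may also point at a local "./dir" or a Hugging Face repo containing a config.json. With no checkpoint involved, the third return value is None, since pretrained weights carry no training configuration.

    from argparse import ArgumentParser

    from nano_gpt.tool.model_config import create_model_arguments, model_from_args

    parser = ArgumentParser()
    create_model_arguments(parser)
    args = parser.parse_args(["--pretrained", "gpt2", "--device", "cpu"])

    # No checkpoint in play, so pass None.
    model, tokenizer, trained_config = model_from_args(args, checkpoint=None)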

def create_eval_arguments(args: argparse.ArgumentParser) -> None:

Create arguments for model evaluation.

def eval_config_from_args(args: Any, checkpoint: nano_gpt.checkpoint.Checkpoint | None) -> nano_gpt.config.EvalConfig:

Create an EvalConfig from the flags.
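A brief sketch: with no checkpoint, unspecified fields take EvalConfig's defaults; with a checkpoint, its stored eval_config is the base and command line flags override it via dataclasses.replace(). The flag value is illustrative.

    from argparse import ArgumentParser

    from nano_gpt.tool.model_config import create_eval_arguments, eval_config_from_args

    parser = ArgumentParser()
    create_eval_arguments(parser)
    args = parser.parse_args(["--validation-steps", "50"])
    eval_config = eval_config_from_args(args, checkpoint=None)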

def create_sample_arguments(args: argparse.ArgumentParser) -> None:

Create arguments for model sampling.

def sample_config_from_args(args: Any, checkpoint: nano_gpt.checkpoint.Checkpoint | None) -> nano_gpt.config.SampleConfig:

Create a SampleConfig from the flags.
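A brief sketch showing the flag-to-field mapping (--sample-num-sequences to num_return_sequences, --sample-max-length to max_length, --sample-seed to seed); the values are illustrative.

    from argparse import ArgumentParser

    from nano_gpt.tool.model_config import create_sample_arguments, sample_config_from_args

    parser = ArgumentParser()
    create_sample_arguments(parser)
    args = parser.parse_args(["--sample-num-sequences", "3", "--sample-max-length", "64"])
    sample_config = sample_config_from_args(args, checkpoint=None)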

def create_dataset_arguments(args: argparse.ArgumentParser) -> None:

Create arguments for dataset loading.

def dataset_config_from_args(args: Any, checkpoint: nano_gpt.checkpoint.Checkpoint | None) -> nano_gpt.config.DatasetConfig:

Create a DatasetConfig from the flags.
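A brief sketch. Note that dataset_config_from_args() reads args.sequence_length, which is registered by create_model_arguments, so both helpers go on the same parser; this assumes DatasetConfig can supply a default dataset_name when --dataset is omitted.

    from argparse import ArgumentParser

    from nano_gpt.tool.model_config import (
        create_model_arguments,
        create_dataset_arguments,
        dataset_config_from_args,
    )

    parser = ArgumentParser()
    create_model_arguments(parser)
    create_dataset_arguments(parser)
    args = parser.parse_args(["--micro-batch-size", "8", "--sequence-length", "256"])
    dataset_config = dataset_config_from_args(args, checkpoint=None)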