nano_gpt.tool.model_config — shared library of command-line flags for loading models.

This module provides functions for creating and parsing command-line arguments for
loading models, and for converting those arguments into model configurations. It can
load a model from a checkpoint or a pretrained model, or initialize a model from a
pre-defined model configuration from the GPT-2 paper.
"""Shared library for command line flags for loading models.

This module provides functions for creating and parsing command line arguments for
loading models, as well as functions for converting these arguments into model
configurations.

This can be used to load a model from a checkpoint, a pretrained model, or
initialize a model from pre-defined model configuration from the GPT-2 paper.
"""

from argparse import ArgumentParser, BooleanOptionalAction
from collections.abc import Generator
from contextlib import contextmanager
import dataclasses
import json
import logging
import pathlib
from typing import Any, cast

import torch
from huggingface_hub import HfFileSystem

from nano_gpt.checkpoint import load_checkpoint, Checkpoint
from nano_gpt.config import (
    MODELS,
    config_from,
    TrainedModelConfig,
    EvalConfig,
    SampleConfig,
    DatasetConfig,
    model_config_from_pretrained,
    model_config_from_dict,
)
from nano_gpt.datasets import TRAIN_DATASETS
from nano_gpt.devices import get_device
from nano_gpt.model import GPT
from nano_gpt.tokenizer import get_tokenizer, Tokenizer

_LOGGER = logging.getLogger(__name__)

# Default directory for on-disk dataset caches.
DATASET_DIR = "dataset_cache"


def create_model_arguments(
    args: ArgumentParser, default_values: dict[str, Any] | None = None
) -> None:
    """Create arguments for model loading.

    Args:
        args: The parser to register the flags on.
        default_values: Optional flag-default overrides (only ``model`` is
            currently consulted).
    """
    if default_values is None:
        default_values = {}
    group = args.add_argument_group("model")
    group.add_argument(
        "--pretrained",
        type=str,
        help="The name of the pretrained model to use.",
    )
    group.add_argument(
        "--model",
        type=str,
        default=default_values.get("model", "gpt2"),
        choices=sorted(MODELS),
        help="Use the specified model name configuration default values.",
    )
    group.add_argument(
        "--checkpoint",
        type=str,
        help="Load a model from a checkpoint.",
    )
    group.add_argument(
        "--device",
        type=str,
        help="The device to use.",
    )
    group.add_argument(
        "--sequence-length",
        type=int,
        help="The sequence length used for input content in each micro batch.",
    )
    group.add_argument(
        "--seed",
        type=int,
        help="The seed to use for sampling/training.",
    )
    group.add_argument(
        "--compile",
        # BooleanOptionalAction takes no value on the command line, so a
        # `type=` converter is meaningless here (and deprecated in 3.12+).
        action=BooleanOptionalAction,
        default=True,
        help="Will compile the model if supported by the device.",
    )


def _check_model_arguments(args: Any) -> None:
    """Check that the model arguments are valid.

    Raises:
        ValueError: If no model source flag was provided.
    """
    if args.pretrained is None and args.checkpoint is None and args.model is None:
        raise ValueError(
            "Either --pretrained or --checkpoint or --model must be specified"
        )


def model_config_from_args(
    args: Any,
) -> TrainedModelConfig:
    """Create a model config from the flags.

    Builds the config for ``args.model``, overriding only the batch/sequence
    settings the user explicitly provided.
    """
    return config_from(
        args.model,
        **{
            key: value
            # Tuple (not set) so iteration order is deterministic.
            for key in ("micro_batch_size", "sequence_length", "total_batch_size")
            if (value := getattr(args, key, None)) is not None
        },
    )


@contextmanager
def load_checkpoint_context(args: Any) -> Generator[Checkpoint | None, None, None]:
    """Load a checkpoint from the flags.

    This is a context manager so that the checkpoint can be used across multiple calls to
    parse arguments, but then discarded after the checkpoint is no longer needed.
    """
    if args.checkpoint is not None:
        checkpoint_path = pathlib.Path(args.checkpoint)
        _LOGGER.info("Restoring from checkpoint: %s", checkpoint_path)
        yield load_checkpoint(checkpoint_path, args.device)
    else:
        yield None


def _trained_model_config_dict_from_args(args: Any) -> dict[str, Any]:
    """Create a TrainedModelConfig parameter dict from flags.

    Only flags the user actually set (non-None) are included, so the result is
    suitable for overriding an existing config via ``dataclasses.replace``.
    """
    return {
        key: value
        for key in {
            "seed",
            "micro_batch_size",
            "sequence_length",
            "total_batch_size",
            "max_steps",
            "eval_steps",
            "eval_num_samples",
            "checkpoint_steps",
            "checkpoint_dir",
            "log_file",
        }
        if (value := getattr(args, key, None)) is not None
    }


def model_from_args(
    args: Any, checkpoint: Checkpoint | None
) -> tuple[GPT, Tokenizer, TrainedModelConfig | None]:
    """Create a model from the flags.

    Returns the model, the tokenizer, and the trained model config (``None``
    when loading a pretrained model, which carries no training configuration).

    Raises:
        ValueError: If both --pretrained and --checkpoint are specified, or no
            model source flag was given at all.
        RuntimeError: If torch.compile fails; retry with --no-compile.
    """
    _check_model_arguments(args)
    tokenizer = get_tokenizer()
    trained_model_config: TrainedModelConfig | None = None
    if args.pretrained is not None:
        if checkpoint is not None:
            raise ValueError("Cannot specify both --pretrained and --checkpoint")
        # Lazy %-style logger args: defer interpolation to the logger.
        _LOGGER.info("loading weights from pretrained gpt: %s", args.pretrained)
        pretrained_args: dict[str, Any] = {}
        if args.pretrained.startswith("./") or args.pretrained.startswith("/"):
            # If the pretrained model is a local path, we need to load it from the local
            local_path = pathlib.Path(args.pretrained)
            model_config_path = local_path / "config.json"
            _LOGGER.info("Loading model config from %s", model_config_path)
            data = json.loads(model_config_path.read_text())
            model_config = model_config_from_dict(data)
        elif args.pretrained in MODELS:
            _LOGGER.info("Loading known model config: %s", args.pretrained)
            model_config = model_config_from_pretrained(args.pretrained)
        else:
            # Otherwise assume a HuggingFace Hub repo id and fetch its config.
            fs = HfFileSystem()
            # BUGFIX: joining the absolute path "/config.json" onto a pathlib
            # path discards the repo id; join the relative name instead.
            model_config_path = pathlib.Path(args.pretrained) / "config.json"
            _LOGGER.info("Loading model config from %s", model_config_path)
            data = json.loads(fs.read_text(str(model_config_path)))
            model_config = model_config_from_dict(data)
        _LOGGER.info("Initializing model from pretrained config: %s", model_config)
        model = GPT.from_pretrained(
            args.pretrained,
            tokenizer=tokenizer,
            model_config=model_config,
            **pretrained_args,
        )
    elif checkpoint is not None:
        _LOGGER.debug("initializing model from checkpoint: %s", checkpoint.config)
        model = GPT(checkpoint.config, tokenizer=tokenizer)
        model.load_state_dict(checkpoint.model_state_dict_for_inference)
        model_config = checkpoint.config
        # Flags override the training settings recorded in the checkpoint.
        train_config = dataclasses.replace(
            checkpoint.train_config,
            **_trained_model_config_dict_from_args(args),
        )
        trained_model_config = TrainedModelConfig(
            model_name=checkpoint.name or "checkpoint",
            model_config=checkpoint.config,
            train_config=train_config,
        )
    else:
        trained_model_config = config_from(
            args.model,
            **_trained_model_config_dict_from_args(args),
        )
        model_config = trained_model_config.model_config
        _LOGGER.debug("initializing model from config: %s", model_config)
        model = GPT(model_config, tokenizer=tokenizer)
    _LOGGER.info("Trained model config: %s", trained_model_config)
    if args.device is None:
        args.device = get_device()
    # TODO: Fix compilation with DDP
    if args.device == "cuda":
        if args.compile:
            _LOGGER.info("Compiling model")
            try:
                model = cast(GPT, torch.compile(model))
            except RuntimeError as err:
                raise RuntimeError(
                    f"Failed to compile model, try with --no-compile: {err}"
                ) from err
        else:
            _LOGGER.debug("Not compiling model")
    else:
        _LOGGER.debug("Model will not be compiled (%s)", args.device)

    seed: int | None = None
    if (
        trained_model_config is not None
        and trained_model_config.train_config.seed is not None
    ):
        seed = trained_model_config.train_config.seed
    if args.seed is not None:
        # An explicit --seed flag wins over the checkpoint/config seed.
        seed = args.seed

    if seed is not None:
        _LOGGER.info("Setting seed to %s", seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    return model, tokenizer, trained_model_config


def create_eval_arguments(args: ArgumentParser) -> None:
    """Create arguments for model evaluation."""
    group = args.add_argument_group("eval")
    group.add_argument(
        "--validation-steps",
        type=int,
        help="Number of validation loss iterations to perform each eval round.",
    )
    group.add_argument(
        "--hellaswag-samples",
        type=int,
        help="The number of HellaSwag evaluation results to sample or None for the entire set.",
    )


def eval_config_from_args(args: Any, checkpoint: Checkpoint | None) -> EvalConfig:
    """Create an EvalConfig from the flags.

    Flag values override the checkpoint's recorded eval config when present.
    """
    values = {}
    if args.validation_steps is not None:
        values["validation_steps"] = args.validation_steps
    if args.hellaswag_samples is not None:
        values["hellaswag_samples"] = args.hellaswag_samples
    if checkpoint is not None and checkpoint.eval_config is not None:
        return dataclasses.replace(
            checkpoint.eval_config,
            **values,
        )
    return EvalConfig(**values)


def create_sample_arguments(args: ArgumentParser) -> None:
    """Create arguments for model sampling."""
    group = args.add_argument_group("sample")
    group.add_argument(
        "--sample-num-sequences",
        type=int,
        help="The number of sequences to generate.",
    )
    group.add_argument(
        "--sample-max-length",
        type=int,
        help="The maximum length of the generated sequences.",
    )
    group.add_argument(
        "--sample-seed",
        type=int,
        help="The seed to use for sampling.",
    )


def sample_config_from_args(args: Any, checkpoint: Checkpoint | None) -> SampleConfig:
    """Create a SampleConfig from the flags.

    Flag values override the checkpoint's recorded sample config when present.
    """
    values = {}
    if args.sample_num_sequences is not None:
        values["num_return_sequences"] = args.sample_num_sequences
    if args.sample_max_length is not None:
        values["max_length"] = args.sample_max_length
    if args.sample_seed is not None:
        values["seed"] = args.sample_seed
    if checkpoint is not None and checkpoint.sample_config is not None:
        return dataclasses.replace(
            checkpoint.sample_config,
            **values,
        )
    return SampleConfig(**values)


def create_dataset_arguments(args: ArgumentParser) -> None:
    """Create arguments for dataset loading."""
    group = args.add_argument_group("dataset")
    group.add_argument(
        "--dataset",
        type=str,
        help="Use the specified dataset.",
        choices=TRAIN_DATASETS.keys(),
        required=False,
    )
    group.add_argument(
        "--dataset-dir",
        type=str,
        help="Directory where the dataset is stored.",
        default=DATASET_DIR,
    )
    # Register on the group (not the bare parser) so the flag is listed under
    # the "dataset" section in --help; parsing behavior is unchanged.
    group.add_argument(
        "--micro-batch-size",
        type=int,
        help="The number of batches of examples to pull from the dataset in each micro step.",
    )


def dataset_config_from_args(args: Any, checkpoint: Checkpoint | None) -> DatasetConfig:
    """Create a DatasetConfig from the flags.

    Flag values override the checkpoint's recorded dataset config when present.
    """
    values = {}
    if args.dataset is not None:
        values["dataset_name"] = args.dataset
    if args.dataset_dir is not None:
        values["dataset_dir"] = args.dataset_dir
    if args.micro_batch_size is not None:
        values["micro_batch_size"] = args.micro_batch_size
    if args.sequence_length is not None:
        values["sequence_length"] = args.sequence_length
    if checkpoint is not None and checkpoint.dataset_config is not None:
        return dataclasses.replace(
            checkpoint.dataset_config,
            **values,
        )
    return DatasetConfig(**values)
def create_model_arguments(
    args: ArgumentParser, default_values: dict[str, Any] | None = None
) -> None:
    """Create arguments for model loading.

    Args:
        args: The parser to register the flags on.
        default_values: Optional flag-default overrides (only ``model`` is
            currently consulted).
    """
    if default_values is None:
        default_values = {}
    group = args.add_argument_group("model")
    group.add_argument(
        "--pretrained",
        type=str,
        help="The name of the pretrained model to use.",
    )
    group.add_argument(
        "--model",
        type=str,
        default=default_values.get("model", "gpt2"),
        choices=sorted(MODELS),
        help="Use the specified model name configuration default values.",
    )
    group.add_argument(
        "--checkpoint",
        type=str,
        help="Load a model from a checkpoint.",
    )
    group.add_argument(
        "--device",
        type=str,
        help="The device to use.",
    )
    group.add_argument(
        "--sequence-length",
        type=int,
        help="The sequence length used for input content in each micro batch.",
    )
    group.add_argument(
        "--seed",
        type=int,
        help="The seed to use for sampling/training.",
    )
    group.add_argument(
        "--compile",
        # BooleanOptionalAction takes no value on the command line, so a
        # `type=` converter is meaningless here (and deprecated in 3.12+).
        action=BooleanOptionalAction,
        default=True,
        help="Will compile the model if supported by the device.",
    )
Create arguments for model loading.
def model_config_from_args(
    args: Any,
) -> TrainedModelConfig:
    """Create a model config from the flags.

    Builds the config for ``args.model``, overriding only the batch/sequence
    settings the user explicitly provided (non-None).
    """
    return config_from(
        args.model,
        **{
            key: value
            # Tuple (not set) so iteration order is deterministic.
            for key in ("micro_batch_size", "sequence_length", "total_batch_size")
            if (value := getattr(args, key, None)) is not None
        },
    )
Create a model from the flags.
@contextmanager
def load_checkpoint_context(args: Any) -> Generator[Checkpoint | None, None, None]:
    """Yield the checkpoint named by the flags, or None when no flag was given.

    Implemented as a context manager so the (potentially large) checkpoint can
    be shared across several argument-parsing calls and then released as soon
    as the `with` block exits.
    """
    if args.checkpoint is None:
        # No --checkpoint flag: nothing to load or release.
        yield None
        return
    path = pathlib.Path(args.checkpoint)
    _LOGGER.info("Restoring from checkpoint: %s", path)
    yield load_checkpoint(path, args.device)
Load a checkpoint from the flags.
This is a context manager so that the checkpoint can be used across multiple calls to parse arguments, but then discarded after the checkpoint is no longer needed.
def model_from_args(
    args: Any, checkpoint: Checkpoint | None
) -> tuple[GPT, Tokenizer, TrainedModelConfig | None]:
    """Create a model from the flags.

    Returns the model, the tokenizer, and the trained model config (``None``
    when loading a pretrained model, which carries no training configuration).

    Raises:
        ValueError: If both --pretrained and --checkpoint are specified, or no
            model source flag was given at all.
        RuntimeError: If torch.compile fails; retry with --no-compile.
    """
    _check_model_arguments(args)
    tokenizer = get_tokenizer()
    trained_model_config: TrainedModelConfig | None = None
    if args.pretrained is not None:
        if checkpoint is not None:
            raise ValueError("Cannot specify both --pretrained and --checkpoint")
        # Lazy %-style logger args: defer interpolation to the logger.
        _LOGGER.info("loading weights from pretrained gpt: %s", args.pretrained)
        pretrained_args: dict[str, Any] = {}
        if args.pretrained.startswith("./") or args.pretrained.startswith("/"):
            # If the pretrained model is a local path, we need to load it from the local
            local_path = pathlib.Path(args.pretrained)
            model_config_path = local_path / "config.json"
            _LOGGER.info("Loading model config from %s", model_config_path)
            data = json.loads(model_config_path.read_text())
            model_config = model_config_from_dict(data)
        elif args.pretrained in MODELS:
            _LOGGER.info("Loading known model config: %s", args.pretrained)
            model_config = model_config_from_pretrained(args.pretrained)
        else:
            # Otherwise assume a HuggingFace Hub repo id and fetch its config.
            fs = HfFileSystem()
            # BUGFIX: joining the absolute path "/config.json" onto a pathlib
            # path discards the repo id; join the relative name instead.
            model_config_path = pathlib.Path(args.pretrained) / "config.json"
            _LOGGER.info("Loading model config from %s", model_config_path)
            data = json.loads(fs.read_text(str(model_config_path)))
            model_config = model_config_from_dict(data)
        _LOGGER.info("Initializing model from pretrained config: %s", model_config)
        model = GPT.from_pretrained(
            args.pretrained,
            tokenizer=tokenizer,
            model_config=model_config,
            **pretrained_args,
        )
    elif checkpoint is not None:
        _LOGGER.debug("initializing model from checkpoint: %s", checkpoint.config)
        model = GPT(checkpoint.config, tokenizer=tokenizer)
        model.load_state_dict(checkpoint.model_state_dict_for_inference)
        model_config = checkpoint.config
        # Flags override the training settings recorded in the checkpoint.
        train_config = dataclasses.replace(
            checkpoint.train_config,
            **_trained_model_config_dict_from_args(args),
        )
        trained_model_config = TrainedModelConfig(
            model_name=checkpoint.name or "checkpoint",
            model_config=checkpoint.config,
            train_config=train_config,
        )
    else:
        trained_model_config = config_from(
            args.model,
            **_trained_model_config_dict_from_args(args),
        )
        model_config = trained_model_config.model_config
        _LOGGER.debug("initializing model from config: %s", model_config)
        model = GPT(model_config, tokenizer=tokenizer)
    _LOGGER.info("Trained model config: %s", trained_model_config)
    if args.device is None:
        args.device = get_device()
    # TODO: Fix compilation with DDP
    if args.device == "cuda":
        if args.compile:
            _LOGGER.info("Compiling model")
            try:
                model = cast(GPT, torch.compile(model))
            except RuntimeError as err:
                raise RuntimeError(
                    f"Failed to compile model, try with --no-compile: {err}"
                ) from err
        else:
            _LOGGER.debug("Not compiling model")
    else:
        _LOGGER.debug("Model will not be compiled (%s)", args.device)

    seed: int | None = None
    if (
        trained_model_config is not None
        and trained_model_config.train_config.seed is not None
    ):
        seed = trained_model_config.train_config.seed
    if args.seed is not None:
        # An explicit --seed flag wins over the checkpoint/config seed.
        seed = args.seed

    if seed is not None:
        _LOGGER.info("Setting seed to %s", seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    return model, tokenizer, trained_model_config
Create a model from the flags.
def create_eval_arguments(args: ArgumentParser) -> None:
    """Register the command line flags used for model evaluation."""
    eval_group = args.add_argument_group("eval")
    for flag, help_text in (
        (
            "--validation-steps",
            "Number of validation loss iterations to perform each eval round.",
        ),
        (
            "--hellaswag-samples",
            "The number of HellaSwag evaluation results to sample or None for the entire set.",
        ),
    ):
        eval_group.add_argument(flag, type=int, help=help_text)
Create arguments for model evaluation.
def eval_config_from_args(args: Any, checkpoint: Checkpoint | None) -> EvalConfig:
    """Create an EvalConfig from the flags.

    Flag values take precedence over the checkpoint's recorded eval config.
    """
    overrides: dict[str, Any] = {
        name: value
        for name, value in (
            ("validation_steps", args.validation_steps),
            ("hellaswag_samples", args.hellaswag_samples),
        )
        if value is not None
    }
    if checkpoint is not None and checkpoint.eval_config is not None:
        return dataclasses.replace(checkpoint.eval_config, **overrides)
    return EvalConfig(**overrides)
Create an EvalConfig from the flags.
def create_sample_arguments(args: ArgumentParser) -> None:
    """Register the command line flags used for model sampling."""
    sample_group = args.add_argument_group("sample")
    for flag, help_text in (
        ("--sample-num-sequences", "The number of sequences to generate."),
        ("--sample-max-length", "The maximum length of the generated sequences."),
        ("--sample-seed", "The seed to use for sampling."),
    ):
        sample_group.add_argument(flag, type=int, help=help_text)
Create arguments for model sampling.
def sample_config_from_args(args: Any, checkpoint: Checkpoint | None) -> SampleConfig:
    """Create a SampleConfig from the flags.

    Flag values take precedence over the checkpoint's recorded sample config.
    """
    overrides: dict[str, Any] = {
        name: value
        for name, value in (
            ("num_return_sequences", args.sample_num_sequences),
            ("max_length", args.sample_max_length),
            ("seed", args.sample_seed),
        )
        if value is not None
    }
    if checkpoint is not None and checkpoint.sample_config is not None:
        return dataclasses.replace(checkpoint.sample_config, **overrides)
    return SampleConfig(**overrides)
Create a SampleConfig from the flags.
def create_dataset_arguments(args: ArgumentParser) -> None:
    """Create arguments for dataset loading.

    Args:
        args: The parser to register the flags on.
    """
    group = args.add_argument_group("dataset")
    group.add_argument(
        "--dataset",
        type=str,
        help="Use the specified dataset.",
        choices=TRAIN_DATASETS.keys(),
        required=False,
    )
    group.add_argument(
        "--dataset-dir",
        type=str,
        help="Directory where the dataset is stored.",
        default=DATASET_DIR,
    )
    # Register on the group (not the bare parser) so the flag is listed under
    # the "dataset" section in --help; parsing behavior is unchanged.
    group.add_argument(
        "--micro-batch-size",
        type=int,
        help="The number of batches of examples to pull from the dataset in each micro step.",
    )
Create arguments for dataset loading.
def dataset_config_from_args(args: Any, checkpoint: Checkpoint | None) -> DatasetConfig:
    """Create a DatasetConfig from the flags.

    Flag values take precedence over the checkpoint's recorded dataset config.
    """
    overrides: dict[str, Any] = {}
    for field, value in (
        ("dataset_name", args.dataset),
        ("dataset_dir", args.dataset_dir),
        ("micro_batch_size", args.micro_batch_size),
        ("sequence_length", args.sequence_length),
    ):
        if value is not None:
            overrides[field] = value
    if checkpoint is not None and checkpoint.dataset_config is not None:
        return dataclasses.replace(checkpoint.dataset_config, **overrides)
    return DatasetConfig(**overrides)
Create a DatasetConfig from the flags.