nano_gpt.tool.train
Command-line interface for training the model.
Usage:

```
usage: nano-gpt train [-h] [--pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}]
                      [--model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}] [--checkpoint CHECKPOINT] [--device DEVICE]
                      [--sequence-length SEQUENCE_LENGTH] [--seed SEED] [--compile | --no-compile]
                      [--total-batch-size TOTAL_BATCH_SIZE] [--streaming | --no-streaming] [--max-steps MAX_STEPS]
                      [--eval-steps EVAL_STEPS] [--checkpoint-steps CHECKPOINT_STEPS] [--checkpoint-dir CHECKPOINT_DIR]
                      [--log-file LOG_FILE] [--validation-steps VALIDATION_STEPS] [--hellaswag-samples HELLASWAG_SAMPLES]
                      [--sample-num-sequences SAMPLE_NUM_SEQUENCES] [--sample-max-length SAMPLE_MAX_LENGTH]
                      [--sample-seed SAMPLE_SEED] [--dataset {finewebedu,tinyshakespeare}] [--dataset-dir DATASET_DIR]
                      [--micro-batch-size MICRO_BATCH_SIZE]

Train a model

options:
  -h, --help            show this help message and exit
  --total-batch-size TOTAL_BATCH_SIZE
                        The number of tokens to use in each gradient accumulation batch (of micro-batches).
  --streaming, --no-streaming
                        Stream the dataset without downloading the entire corpus.
  --max-steps MAX_STEPS
                        The maximum number of training steps.
  --eval-steps EVAL_STEPS
                        The number of steps between evaluations.
  --checkpoint-steps CHECKPOINT_STEPS
                        The number of steps between checkpoints.
  --checkpoint-dir CHECKPOINT_DIR
                        The path to the checkpoint directory.
  --log-file LOG_FILE   The path to the log file.
  --micro-batch-size MICRO_BATCH_SIZE
                        The number of batches of examples to pull from the dataset in each micro step.

model:
  --pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}
                        The name of the pretrained model to use.
  --model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}
                        Use the specified model name configuration default values.
  --checkpoint CHECKPOINT
                        Load a model from a checkpoint.
  --device DEVICE       The device to use.
  --sequence-length SEQUENCE_LENGTH
                        The sequence length used for input content in each micro batch.
  --seed SEED           The seed to use for sampling/training.
  --compile, --no-compile
                        Compile the model if supported by the device.

eval:
  --validation-steps VALIDATION_STEPS
                        Number of validation loss iterations to perform each eval round.
  --hellaswag-samples HELLASWAG_SAMPLES
                        The number of HellaSwag evaluation results to sample, or None for the entire set.

sample:
  --sample-num-sequences SAMPLE_NUM_SEQUENCES
                        The number of sequences to generate.
  --sample-max-length SAMPLE_MAX_LENGTH
                        The maximum length of the generated sequences.
  --sample-seed SAMPLE_SEED
                        The seed to use for sampling.

dataset:
  --dataset {finewebedu,tinyshakespeare}
                        Use the specified dataset.
  --dataset-dir DATASET_DIR
                        Directory where the dataset is stored.
```
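The batch-size flags interact: `--total-batch-size` counts tokens per optimizer step, which are accumulated over micro-batches of `--micro-batch-size` sequences of `--sequence-length` tokens each. A minimal sketch of the implied gradient-accumulation arithmetic (the exact formula in the trainer is not shown here, so the DDP world-size factor is an assumption):

```python
# Sketch of how the batch flags typically combine into gradient accumulation.
total_batch_size = 524288   # --total-batch-size: tokens per optimizer step (2**19)
micro_batch_size = 16       # --micro-batch-size: sequences per micro-step
sequence_length = 1024      # --sequence-length: tokens per sequence
world_size = 1              # DDP worker count (assumed factor)

tokens_per_micro_step = micro_batch_size * sequence_length * world_size
assert total_batch_size % tokens_per_micro_step == 0, "must divide evenly"
grad_accum_steps = total_batch_size // tokens_per_micro_step
print(grad_accum_steps)  # 32 micro-steps accumulated per optimizer step
```

With these values each optimizer step accumulates gradients over 32 forward/backward passes before updating the weights.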
````python
"""Command-line interface for training the model.

Usage:
```
usage: nano-gpt train [-h] [--pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}]
                      [--model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}] [--checkpoint CHECKPOINT] [--device DEVICE]
                      [--sequence-length SEQUENCE_LENGTH] [--seed SEED] [--compile | --no-compile]
                      [--total-batch-size TOTAL_BATCH_SIZE] [--streaming | --no-streaming] [--max-steps MAX_STEPS]
                      [--eval-steps EVAL_STEPS] [--checkpoint-steps CHECKPOINT_STEPS] [--checkpoint-dir CHECKPOINT_DIR]
                      [--log-file LOG_FILE] [--validation-steps VALIDATION_STEPS] [--hellaswag-samples HELLASWAG_SAMPLES]
                      [--sample-num-sequences SAMPLE_NUM_SEQUENCES] [--sample-max-length SAMPLE_MAX_LENGTH]
                      [--sample-seed SAMPLE_SEED] [--dataset {finewebedu,tinyshakespeare}] [--dataset-dir DATASET_DIR]
                      [--micro-batch-size MICRO_BATCH_SIZE]

Train a model

options:
  -h, --help            show this help message and exit
  --total-batch-size TOTAL_BATCH_SIZE
                        The number of tokens to use in each gradient accumulation batch (of micro-batches).
  --streaming, --no-streaming
                        Stream the dataset without downloading the entire corpus.
  --max-steps MAX_STEPS
                        The maximum number of training steps.
  --eval-steps EVAL_STEPS
                        The number of steps between evaluations.
  --checkpoint-steps CHECKPOINT_STEPS
                        The number of steps between checkpoints.
  --checkpoint-dir CHECKPOINT_DIR
                        The path to the checkpoint directory.
  --log-file LOG_FILE   The path to the log file.
  --micro-batch-size MICRO_BATCH_SIZE
                        The number of batches of examples to pull from the dataset in each micro step.

model:
  --pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}
                        The name of the pretrained model to use.
  --model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}
                        Use the specified model name configuration default values.
  --checkpoint CHECKPOINT
                        Load a model from a checkpoint.
  --device DEVICE       The device to use.
  --sequence-length SEQUENCE_LENGTH
                        The sequence length used for input content in each micro batch.
  --seed SEED           The seed to use for sampling/training.
  --compile, --no-compile
                        Compile the model if supported by the device.

eval:
  --validation-steps VALIDATION_STEPS
                        Number of validation loss iterations to perform each eval round.
  --hellaswag-samples HELLASWAG_SAMPLES
                        The number of HellaSwag evaluation results to sample, or None for the entire set.

sample:
  --sample-num-sequences SAMPLE_NUM_SEQUENCES
                        The number of sequences to generate.
  --sample-max-length SAMPLE_MAX_LENGTH
                        The maximum length of the generated sequences.
  --sample-seed SAMPLE_SEED
                        The seed to use for sampling.

dataset:
  --dataset {finewebedu,tinyshakespeare}
                        Use the specified dataset.
  --dataset-dir DATASET_DIR
                        Directory where the dataset is stored.
```
"""

import argparse
import datetime
import logging
from collections.abc import Iterable

import torch

from nano_gpt.checkpoint import CHECKPOINT_DIR
from nano_gpt.datasets import hellaswag
from nano_gpt.datasets.data_loader import read_preprocessed_corpus
from nano_gpt.trainer import WorkerState, create_optimizer, train

from .model_config import (
    create_model_arguments,
    create_eval_arguments,
    create_sample_arguments,
    create_dataset_arguments,
    dataset_config_from_args,
    eval_config_from_args,
    model_from_args,
    sample_config_from_args,
    load_checkpoint_context,
)

_LOGGER = logging.getLogger(__name__)


def create_arguments(args: argparse.ArgumentParser) -> None:
    """Add the training command-line arguments to the parser."""
    create_model_arguments(args)
    args.add_argument(
        "--total-batch-size",
        type=int,
        help="The number of tokens to use in each gradient accumulation batch (of micro-batches).",
    )
    args.add_argument(
        "--streaming",
        # BooleanOptionalAction takes no value, so no type= conversion applies.
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Stream the dataset without downloading the entire corpus.",
    )
    args.add_argument(
        "--max-steps",
        type=int,
        default=None,
        help="The maximum number of training steps.",
    )
    args.add_argument(
        "--eval-steps",
        type=int,
        default=250,
        help="The number of steps between evaluations.",
    )
    args.add_argument(
        "--checkpoint-steps",
        type=int,
        default=None,
        help="The number of steps between checkpoints.",
    )
    args.add_argument(
        "--checkpoint-dir",
        type=str,
        default=str(CHECKPOINT_DIR),
        help="The path to the checkpoint directory.",
    )
    args.add_argument(
        "--log-file",
        type=str,
        default="train_{now}.log".format(
            now=datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        ),
        help="The path to the log file.",
    )

    create_eval_arguments(args)
    create_sample_arguments(args)
    create_dataset_arguments(args)


def run(args: argparse.Namespace) -> int:
    """Run the train command."""
    torch.set_float32_matmul_precision("high")

    with load_checkpoint_context(args) as checkpoint:
        model, tokenizer, config = model_from_args(args, checkpoint)
        if config is None:
            raise ValueError("No trainable model configuration found")
        eval_config = eval_config_from_args(args, checkpoint)
        _LOGGER.info(f"Eval config: {eval_config}")
        sample_config = sample_config_from_args(args, checkpoint)
        _LOGGER.info(f"Sample config: {sample_config}")
        dataset_config = dataset_config_from_args(args, checkpoint)
        if dataset_config.dataset_name is None:
            raise ValueError("Required flag --dataset is missing")
        _LOGGER.info(f"Dataset config: {dataset_config}")

        worker_state = WorkerState(args.device)
        _LOGGER.info("Worker state: %s", worker_state)

        optimizer = create_optimizer(
            model,
            config.train_config,
            checkpoint,
            worker_state.is_cuda,
        )

        _LOGGER.info("Loading dataset %s (streaming=%s)", args.dataset, args.streaming)
        train_data_loader = read_preprocessed_corpus(
            dataset_config.dataset_path("train"),
            dataset_config,
            worker_num=worker_state.ddp_rank,
            worker_count=worker_state.ddp_world_size,
        )
        val_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]] | None = None
        hellaswag_val: Iterable[hellaswag.Sample] | None = None
        if eval_config.validation_steps:
            val_data_loader = read_preprocessed_corpus(
                dataset_config.dataset_path("validation"),
                dataset_config,
                worker_num=worker_state.ddp_rank,
                worker_count=worker_state.ddp_world_size,
            )
        if eval_config.hellaswag_samples is not None:
            hellaswag_val = hellaswag.load_dataset("validation")
        _LOGGER.info("Dataset loaded")
        train(
            model,
            optimizer,
            worker_state,
            config.train_config,
            train_data_loader=train_data_loader,
            eval_config=eval_config,
            dataset_config=dataset_config,
            val_data_loader=val_data_loader,
            hellaswag_loader=hellaswag_val,
            sample_config=sample_config,
        )
    return 0
````
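The `--streaming/--no-streaming` and `--compile/--no-compile` flag pairs come from `argparse.BooleanOptionalAction` (available since Python 3.9), which registers both the positive and negated forms from a single `add_argument` call. A standalone sketch:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--streaming",
    # Registers both --streaming and --no-streaming automatically.
    action=argparse.BooleanOptionalAction,
    default=False,
)

print(parser.parse_args([]).streaming)                  # False (the default)
print(parser.parse_args(["--streaming"]).streaming)     # True
print(parser.parse_args(["--no-streaming"]).streaming)  # False
```

This is why the usage line shows `[--streaming | --no-streaming]` even though the source adds only one argument.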
def create_arguments(args: argparse.ArgumentParser) -> None:
Add the training command-line arguments to the parser.
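`create_arguments` takes an existing parser rather than building one, which suggests it is attached to a `train` subcommand by a top-level entry point. That entry point is not shown here, so the wiring below is a hypothetical sketch with one stand-in flag:

```python
import argparse

# Hypothetical wiring (the real nano-gpt entry point is not shown in this module):
# a top-level parser hands a "train" subparser to create_arguments to populate.
parser = argparse.ArgumentParser(prog="nano-gpt")
subparsers = parser.add_subparsers(dest="command")
train_parser = subparsers.add_parser("train", description="Train a model")
# create_arguments(train_parser) would add the full flag set; one flag stands in:
train_parser.add_argument("--max-steps", type=int, default=None)

args = parser.parse_args(["train", "--max-steps", "100"])
print(args.command, args.max_steps)  # train 100
```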
def run(args: argparse.Namespace) -> int:
Run the train command.
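The default for `--log-file` is computed once at argument-creation time from the current clock. A sketch of the pattern the source builds:

```python
import datetime

# Mirrors the default in create_arguments: train_<YYYY-MM-DD_HH-MM-SS>.log
now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_file = f"train_{now}.log"
print(log_file)  # e.g. train_2025-01-01_12-00-00.log
```

Because the timestamp is evaluated when the parser is built, each invocation gets a fresh default log file unless `--log-file` is passed explicitly.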