nano_gpt.tool.train

Command-line interface for training the model.

"""Command-line interface for training the model.

Usage:
```
usage: nano-gpt train [-h] [--pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}]
                      [--model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}] [--checkpoint CHECKPOINT] [--device DEVICE]
                      [--sequence-length SEQUENCE_LENGTH] [--seed SEED] [--compile | --no-compile]
                      [--total-batch-size TOTAL_BATCH_SIZE] [--streaming | --no-streaming] [--max-steps MAX_STEPS]
                      [--eval-steps EVAL_STEPS] [--checkpoint-steps CHECKPOINT_STEPS] [--checkpoint-dir CHECKPOINT_DIR]
                      [--log-file LOG_FILE] [--validation-steps VALIDATION_STEPS] [--hellaswag-samples HELLASWAG_SAMPLES]
                      [--sample-num-sequences SAMPLE_NUM_SEQUENCES] [--sample-max-length SAMPLE_MAX_LENGTH]
                      [--sample-seed SAMPLE_SEED] [--dataset {finewebedu,tinyshakespeare}] [--dataset-dir DATASET_DIR]
                      [--micro-batch-size MICRO_BATCH_SIZE]

Train a model

options:
  -h, --help            show this help message and exit
  --total-batch-size TOTAL_BATCH_SIZE
                        The number of tokens to use in each gradient accumulation batch (of micro-batches).
  --streaming, --no-streaming
                        Stream the dataset without downloading the entire corpus.
  --max-steps MAX_STEPS
                        The maximum number of training steps.
  --eval-steps EVAL_STEPS
                        The number of steps between evaluations.
  --checkpoint-steps CHECKPOINT_STEPS
                        The number of steps between checkpoints.
  --checkpoint-dir CHECKPOINT_DIR
                        The path to the checkpoint directory
  --log-file LOG_FILE   The path to the log file.
  --micro-batch-size MICRO_BATCH_SIZE
                        The number of batches of examples to pull from the dataset in each micro step.

model:
  --pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}
                        The name of the pretrained model to use.
  --model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}
                        Use the specified model name configuration default values.
  --checkpoint CHECKPOINT
                        Load a model from a checkpoint.
  --device DEVICE       The device to use.
  --sequence-length SEQUENCE_LENGTH
                        The sequence length used for input content in each micro batch.
  --seed SEED           The seed to use for sampling/training.
  --compile, --no-compile
                        Will compile the model if supported by the device.

eval:
  --validation-steps VALIDATION_STEPS
                        Number of validation loss iterations to perform each eval round.
  --hellaswag-samples HELLASWAG_SAMPLES
                        The number of HellaSwag evaluation results to sample or None for the entire set.

sample:
  --sample-num-sequences SAMPLE_NUM_SEQUENCES
                        The number of sequences to generate.
  --sample-max-length SAMPLE_MAX_LENGTH
                        The maximum length of the generated sequences.
  --sample-seed SAMPLE_SEED
                        The seed to use for sampling.

dataset:
  --dataset {finewebedu,tinyshakespeare}
                        Use the specified dataset.
  --dataset-dir DATASET_DIR
                        Directory where the dataset is stored.
```
"""

import argparse
import datetime
import logging
from collections.abc import Iterable

import torch

from nano_gpt.checkpoint import CHECKPOINT_DIR
from nano_gpt.datasets import hellaswag
from nano_gpt.datasets.data_loader import read_preprocessed_corpus
from nano_gpt.trainer import WorkerState, create_optimizer, train
from .model_config import (
    create_model_arguments,
    create_eval_arguments,
    create_sample_arguments,
    create_dataset_arguments,
    dataset_config_from_args,
    eval_config_from_args,
    model_from_args,
    sample_config_from_args,
    load_checkpoint_context,
)


_LOGGER = logging.getLogger(__name__)

def create_arguments(args: argparse.ArgumentParser) -> None:
    """Add training arguments to the parser."""
    create_model_arguments(args)
    args.add_argument(
        "--total-batch-size",
        type=int,
        help="The number of tokens to use in each gradient accumulation batch (of micro-batches).",
    )
    args.add_argument(
        "--streaming",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Stream the dataset without downloading the entire corpus.",
    )
    args.add_argument(
        "--max-steps",
        type=int,
        default=None,
        help="The maximum number of training steps.",
    )
    args.add_argument(
        "--eval-steps",
        type=int,
        default=250,
        help="The number of steps between evaluations.",
    )
    args.add_argument(
        "--checkpoint-steps",
        type=int,
        default=None,
        help="The number of steps between checkpoints.",
    )
    args.add_argument(
        "--checkpoint-dir",
        type=str,
        default=str(CHECKPOINT_DIR),
        help="The path to the checkpoint directory.",
    )
    args.add_argument(
        "--log-file",
        type=str,
        default="train_{now}.log".format(
            now=datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        ),
        help="The path to the log file.",
    )

    create_eval_arguments(args)
    create_sample_arguments(args)
    create_dataset_arguments(args)

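The `--streaming`/`--no-streaming` pair is produced by a single `argparse.BooleanOptionalAction` argument, as used above. A minimal self-contained sketch of that behavior (the parser here is standalone and illustrative, not the module's own):

```python
import argparse

# BooleanOptionalAction (Python 3.9+) registers both --streaming and
# --no-streaming from one add_argument call; no type= is needed because
# the action takes no value.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--streaming",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Stream the dataset without downloading the entire corpus.",
)

print(parser.parse_args([]).streaming)                  # False (the default)
print(parser.parse_args(["--streaming"]).streaming)     # True
print(parser.parse_args(["--no-streaming"]).streaming)  # False
```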

def run(args: argparse.Namespace) -> int:
    """Run the train command."""
    torch.set_float32_matmul_precision("high")

    with load_checkpoint_context(args) as checkpoint:
        model, tokenizer, config = model_from_args(args, checkpoint)
        if config is None:
            raise ValueError("No trainable model configuration found")
        eval_config = eval_config_from_args(args, checkpoint)
        _LOGGER.info("Eval config: %s", eval_config)
        sample_config = sample_config_from_args(args, checkpoint)
        _LOGGER.info("Sample config: %s", sample_config)
        dataset_config = dataset_config_from_args(args, checkpoint)
        if dataset_config.dataset_name is None:
            raise ValueError("Required flag --dataset is missing")
        _LOGGER.info("Dataset config: %s", dataset_config)

        worker_state = WorkerState(args.device)
        _LOGGER.info("Worker state: %s", worker_state)

        optimizer = create_optimizer(
            model,
            config.train_config,
            checkpoint,
            worker_state.is_cuda,
        )

    _LOGGER.info("Loading dataset %s (streaming=%s)", args.dataset, args.streaming)
    train_data_loader = read_preprocessed_corpus(
        dataset_config.dataset_path("train"),
        dataset_config,
        worker_num=worker_state.ddp_rank,
        worker_count=worker_state.ddp_world_size,
    )
    val_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]] | None = None
    hellaswag_val: Iterable[hellaswag.Sample] | None = None
    if eval_config.validation_steps:
        val_data_loader = read_preprocessed_corpus(
            dataset_config.dataset_path("validation"),
            dataset_config,
            worker_num=worker_state.ddp_rank,
            worker_count=worker_state.ddp_world_size,
        )
    if eval_config.hellaswag_samples is not None:
        hellaswag_val = hellaswag.load_dataset("validation")
    _LOGGER.info("Dataset loaded")
    train(
        model,
        optimizer,
        worker_state,
        config.train_config,
        train_data_loader=train_data_loader,
        eval_config=eval_config,
        dataset_config=dataset_config,
        val_data_loader=val_data_loader,
        hellaswag_loader=hellaswag_val,
        sample_config=sample_config,
    )
    return 0
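Two of the training flags interact: `--total-batch-size` is measured in tokens, while `--micro-batch-size` counts examples per micro-step. A trainer along these lines would typically derive the number of accumulated micro-steps per optimizer step from both, roughly as sketched below (the `grad_accum_steps` helper and its example numbers are illustrative, not this project's actual code):

```python
def grad_accum_steps(
    total_batch_size: int,
    micro_batch_size: int,
    sequence_length: int,
    world_size: int = 1,
) -> int:
    """Hypothetical helper: micro-steps accumulated per optimizer step.

    Each micro-step processes micro_batch_size * sequence_length tokens
    on each of world_size workers; total_batch_size is the target token
    count for one full gradient-accumulation batch.
    """
    tokens_per_micro_step = micro_batch_size * sequence_length * world_size
    if total_batch_size % tokens_per_micro_step != 0:
        raise ValueError(
            "total batch size must be divisible by tokens per micro-step"
        )
    return total_batch_size // tokens_per_micro_step


# e.g. a 524288-token batch built from 16-example micro-batches of
# 1024-token sequences on a single worker:
print(grad_accum_steps(524288, 16, 1024))  # 32
```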