nano_gpt.tool.eval

Command-line interface for evaluating a trained model.

Usage:

usage: nano-gpt eval [-h] [--pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}]
                     [--model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}] [--checkpoint CHECKPOINT] [--device DEVICE]
                     [--sequence-length SEQUENCE_LENGTH] [--seed SEED] [--compile | --no-compile]
                     [--validation-steps VALIDATION_STEPS] [--hellaswag-samples HELLASWAG_SAMPLES]
                     [--dataset {finewebedu,tinyshakespeare}] [--dataset-dir DATASET_DIR] [--micro-batch-size MICRO_BATCH_SIZE]

Evaluate a model

options:
  -h, --help            show this help message and exit
  --micro-batch-size MICRO_BATCH_SIZE
                        The number of batches of examples to pull from the dataset in each micro step.

model:
  --pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}
                        The name of the pretrained model to use.
  --model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}
                        Use the specified model name configuration default values.
  --checkpoint CHECKPOINT
                        Load a model from a checkpoint.
  --device DEVICE       The device to use.
  --sequence-length SEQUENCE_LENGTH
                        The sequence length used for input content in each micro batch.
  --seed SEED           The seed to use for sampling/training.
  --compile, --no-compile
                        Will compile the model if supported by the device.

eval:
  --validation-steps VALIDATION_STEPS
                        Number of validation loss iterations to perform each eval round.
  --hellaswag-samples HELLASWAG_SAMPLES
                        The number of HellaSwag evaluation results to sample or None for the entire set.

dataset:
  --dataset {finewebedu,tinyshakespeare}
                        Use the specified dataset.
  --dataset-dir DATASET_DIR
                        Directory where the dataset is stored.
  1"""Command-line interface for evaling a trained model.
  2
  3Usage:
  4```
  5usage: nano-gpt eval [-h] [--pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}]
  6                     [--model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}] [--checkpoint CHECKPOINT] [--device DEVICE]
  7                     [--sequence-length SEQUENCE_LENGTH] [--seed SEED] [--compile | --no-compile]
  8                     [--validation-steps VALIDATION_STEPS] [--hellaswag-samples HELLASWAG_SAMPLES]
  9                     [--dataset {finewebedu,tinyshakespeare}] [--dataset-dir DATASET_DIR] [--micro-batch-size MICRO_BATCH_SIZE]
 10
 11Evaluate a model
 12
 13options:
 14  -h, --help            show this help message and exit
 15  --micro-batch-size MICRO_BATCH_SIZE
 16                        The number of batches of examples to pull from the dataset in each micro step.
 17
 18model:
 19  --pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}
 20                        The name of the pretrained model to use.
 21  --model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}
 22                        Use the specified model name configuration default values.
 23  --checkpoint CHECKPOINT
 24                        Load a model from a checkpoint.
 25  --device DEVICE       The device to use.
 26  --sequence-length SEQUENCE_LENGTH
 27                        The sequence length used for input content in each micro batch.
 28  --seed SEED           The seed to use for sampling/training.
 29  --compile, --no-compile
 30                        Will compile the model if supported by the device.
 31
 32eval:
 33  --validation-steps VALIDATION_STEPS
 34                        Number of validation loss iterations to perform each eval round.
 35  --hellaswag-samples HELLASWAG_SAMPLES
 36                        The number of HellaSwag evaluation results to sample or None for the entire set.
 37
 38dataset:
 39  --dataset {finewebedu,tinyshakespeare}
 40                        Use the specified dataset.
 41  --dataset-dir DATASET_DIR
 42                        Directory where the dataset is stored.
 43```
 44"""
 45
 46import argparse
 47import logging
 48
 49import torch
 50
 51from nano_gpt.datasets import hellaswag
 52from nano_gpt.datasets.data_loader import read_preprocessed_corpus
 53from nano_gpt import hellaswag_eval, trainer
 54from nano_gpt.config import DatasetConfig
 55from nano_gpt.log import create_log
 56
 57from .model_config import (
 58    create_model_arguments,
 59    model_from_args,
 60    eval_config_from_args,
 61    create_eval_arguments,
 62    create_dataset_arguments,
 63    dataset_config_from_args,
 64    load_checkpoint_context,
 65)
 66
 67_LOGGER = logging.getLogger(__name__)
 68
 69DATASET = "hellaswag"
 70SPLIT = "validation"
 71
 72
 73def create_arguments(args: argparse.ArgumentParser) -> None:
 74    """Get parsed passed in arguments."""
 75    create_model_arguments(args)
 76    create_eval_arguments(args)
 77    create_dataset_arguments(args)
 78
 79
 80def run(args: argparse.Namespace) -> int:
 81    """Run the eval command."""
 82    torch.set_float32_matmul_precision("high")
 83    log = create_log(None, log_stdout=True)
 84
 85    dataset_config: DatasetConfig | None = None
 86    with load_checkpoint_context(args) as checkpoint:
 87        eval_config = eval_config_from_args(args, checkpoint)
 88        _LOGGER.info(f"Eval config: {eval_config}")
 89        model, tokenizer, _ = model_from_args(args, checkpoint)
 90        model.to(args.device)
 91        model.eval()
 92
 93        if eval_config.validation_steps:
 94            dataset_config = dataset_config_from_args(args, checkpoint)
 95            _LOGGER.info(f"Dataset config: {dataset_config}")
 96
 97    if dataset_config is not None and dataset_config.dataset_name is not None:
 98        val_data_loader = read_preprocessed_corpus(
 99            dataset_config.dataset_path(SPLIT),
100            dataset_config,
101        )
102        val_ds = iter(val_data_loader)
103
104        worker_state = trainer.WorkerState(args.device)
105        _LOGGER.info("Worker state: %s", worker_state)
106
107        with torch.no_grad():
108            loss_accum = trainer.compute_loss(
109                model,
110                worker_state,
111                log_label="val",
112                ds=val_ds,
113                steps=eval_config.validation_steps,
114                backward=False,
115            )
116            log.log(
117                trainer.ValStats(
118                    step=eval_config.validation_steps, loss_accum=loss_accum
119                ).log_record()
120            )
121
122    hellaswag_val = hellaswag.load_dataset(SPLIT)
123    hellaswag_result = hellaswag_eval.evaluate(
124        model,
125        tokenizer,
126        hellaswag_val,
127        args.device,
128        eval_config.hellaswag_samples,
129    )
130    log.log(hellaswag_result.log_record())
131
132    return 0
# Benchmark dataset name. NOTE(review): not referenced by the visible module
# code; presumably kept for external callers — confirm.
DATASET = "hellaswag"
# Split read for both the validation-loss corpus and HellaSwag.
SPLIT = "validation"
def create_arguments(args: argparse.ArgumentParser) -> None:
    """Register every option accepted by the ``eval`` command on ``args``.

    Delegates to the shared ``model_config`` helpers, adding the model,
    eval, and dataset options in that order.
    """
    for register in (
        create_model_arguments,
        create_eval_arguments,
        create_dataset_arguments,
    ):
        register(args)

Register the eval command's command-line arguments on the given parser.

def run(args: argparse.Namespace) -> int:
    """Run the eval command.

    Builds the model described by ``args`` (together with the checkpoint, if
    any, opened by ``load_checkpoint_context``), moves it to ``args.device``
    and switches it to eval mode. Then:

    1. If ``eval_config.validation_steps`` is set (truthy) and a dataset is
       configured, computes the loss over that many steps of the ``SPLIT``
       split of the preprocessed corpus and logs it.
    2. Evaluates the model on the HellaSwag ``SPLIT`` split and logs the
       result.

    Args:
        args: Parsed arguments from a parser populated by ``create_arguments``.

    Returns:
        The process exit status; ``0`` whenever no exception is raised.
    """
    torch.set_float32_matmul_precision("high")
    log = create_log(None, log_stdout=True)

    dataset_config: DatasetConfig | None = None
    with load_checkpoint_context(args) as checkpoint:
        eval_config = eval_config_from_args(args, checkpoint)
        _LOGGER.info("Eval config: %s", eval_config)
        model, tokenizer, _ = model_from_args(args, checkpoint)
        model.to(args.device)
        model.eval()

        # The dataset config is only needed for the validation-loss pass, so
        # it is resolved only when validation steps were requested.
        if eval_config.validation_steps:
            dataset_config = dataset_config_from_args(args, checkpoint)
            _LOGGER.info("Dataset config: %s", dataset_config)

    if dataset_config is not None and dataset_config.dataset_name is not None:
        val_data_loader = read_preprocessed_corpus(
            dataset_config.dataset_path(SPLIT),
            dataset_config,
        )
        val_ds = iter(val_data_loader)

        worker_state = trainer.WorkerState(args.device)
        _LOGGER.info("Worker state: %s", worker_state)

        with torch.no_grad():
            loss_accum = trainer.compute_loss(
                model,
                worker_state,
                log_label="val",
                ds=val_ds,
                steps=eval_config.validation_steps,
                backward=False,
            )
            log.log(
                trainer.ValStats(
                    step=eval_config.validation_steps, loss_accum=loss_accum
                ).log_record()
            )

    hellaswag_val = hellaswag.load_dataset(SPLIT)
    # Inference only: disable autograd so the forward passes do not record a
    # graph (harmless if hellaswag_eval.evaluate already does this itself).
    with torch.no_grad():
        hellaswag_result = hellaswag_eval.evaluate(
            model,
            tokenizer,
            hellaswag_val,
            args.device,
            eval_config.hellaswag_samples,
        )
    log.log(hellaswag_result.log_record())

    return 0

Run the eval command.