"""Command-line interface for evaling a trained model.

Usage:
```
usage: nano-gpt eval [-h] [--pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}]
                     [--model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}] [--checkpoint CHECKPOINT] [--device DEVICE]
                     [--sequence-length SEQUENCE_LENGTH] [--seed SEED] [--compile | --no-compile]
                     [--validation-steps VALIDATION_STEPS] [--hellaswag-samples HELLASWAG_SAMPLES]
                     [--dataset {finewebedu,tinyshakespeare}] [--dataset-dir DATASET_DIR] [--micro-batch-size MICRO_BATCH_SIZE]

Evaluate a model

options:
  -h, --help            show this help message and exit
  --micro-batch-size MICRO_BATCH_SIZE
                        The number of batches of examples to pull from the dataset in each micro step.

model:
  --pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}
                        The name of the pretrained model to use.
  --model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}
                        Use the specified model name configuration default values.
  --checkpoint CHECKPOINT
                        Load a model from a checkpoint.
  --device DEVICE       The device to use.
  --sequence-length SEQUENCE_LENGTH
                        The sequence length used for input content in each micro batch.
  --seed SEED           The seed to use for sampling/training.
  --compile, --no-compile
                        Will compile the model if supported by the device.

eval:
  --validation-steps VALIDATION_STEPS
                        Number of validation loss iterations to perform each eval round.
  --hellaswag-samples HELLASWAG_SAMPLES
                        The number of HellaSwag evaluation results to sample or None for the entire set.

dataset:
  --dataset {finewebedu,tinyshakespeare}
                        Use the specified dataset.
  --dataset-dir DATASET_DIR
                        Directory where the dataset is stored.
```
"""

import argparse
import logging

import torch

from nano_gpt.datasets import hellaswag
from nano_gpt.datasets.data_loader import read_preprocessed_corpus
from nano_gpt import hellaswag_eval, trainer
from nano_gpt.config import DatasetConfig
from nano_gpt.log import create_log

from .model_config import (
    create_model_arguments,
    model_from_args,
    eval_config_from_args,
    create_eval_arguments,
    create_dataset_arguments,
    dataset_config_from_args,
    load_checkpoint_context,
)

_LOGGER = logging.getLogger(__name__)

# Name of the benchmark dataset used by this tool.
DATASET = "hellaswag"
# Dataset split evaluated by both the validation-loss and HellaSwag passes.
SPLIT = "validation"


def create_arguments(args: argparse.ArgumentParser) -> None:
    """Get parsed passed in arguments.

    Registers the model, eval, and dataset argument groups on the
    parser for the `nano-gpt eval` subcommand.
    """
    create_model_arguments(args)
    create_eval_arguments(args)
    create_dataset_arguments(args)


def run(args: argparse.Namespace) -> int:
    """Run the eval command.

    Loads the model (optionally from a checkpoint), computes validation
    loss when a dataset is configured, runs the HellaSwag evaluation,
    and logs the results. Returns a process exit code (0 on success).
    """
    torch.set_float32_matmul_precision("high")
    log = create_log(None, log_stdout=True)

    # Pre-declared so the value assigned inside the checkpoint context
    # remains available after the context exits.
    dataset_config: DatasetConfig | None = None
    with load_checkpoint_context(args) as checkpoint:
        eval_config = eval_config_from_args(args, checkpoint)
        _LOGGER.info("Eval config: %s", eval_config)
        model, tokenizer, _ = model_from_args(args, checkpoint)
        model.to(args.device)
        model.eval()

        if eval_config.validation_steps:
            dataset_config = dataset_config_from_args(args, checkpoint)
            _LOGGER.info("Dataset config: %s", dataset_config)

    # Validation loss pass: only when validation steps were requested
    # above and the config actually names a dataset to read.
    if dataset_config is not None and dataset_config.dataset_name is not None:
        val_data_loader = read_preprocessed_corpus(
            dataset_config.dataset_path(SPLIT),
            dataset_config,
        )
        val_ds = iter(val_data_loader)

        worker_state = trainer.WorkerState(args.device)
        _LOGGER.info("Worker state: %s", worker_state)

        # Evaluation only: no gradients, no backward pass.
        with torch.no_grad():
            loss_accum = trainer.compute_loss(
                model,
                worker_state,
                log_label="val",
                ds=val_ds,
                steps=eval_config.validation_steps,
                backward=False,
            )
        log.log(
            trainer.ValStats(
                step=eval_config.validation_steps, loss_accum=loss_accum
            ).log_record()
        )

    # HellaSwag pass always runs; hellaswag_samples limits the number of
    # sampled examples (None evaluates the entire set, per the CLI help).
    hellaswag_val = hellaswag.load_dataset(SPLIT)
    hellaswag_result = hellaswag_eval.evaluate(
        model,
        tokenizer,
        hellaswag_val,
        args.device,
        eval_config.hellaswag_samples,
    )
    log.log(hellaswag_result.log_record())

    return 0