"""Command-line interface for sampling from a trained model.

Usage:
```
usage: nano-gpt sample [-h] [--pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}]
                       [--model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}] [--checkpoint CHECKPOINT] [--device DEVICE]
                       [--sequence-length SEQUENCE_LENGTH] [--seed SEED] [--compile | --no-compile]
                       [--sample-num-sequences SAMPLE_NUM_SEQUENCES] [--sample-max-length SAMPLE_MAX_LENGTH]
                       [--sample-seed SAMPLE_SEED]
                       [text ...]

Sample from a model

positional arguments:
  text                  The text to use as a prompt for sampling.

options:
  -h, --help            show this help message and exit

model:
  --pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}
                        The name of the pretrained model to use.
  --model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}
                        Use the specified model name configuration default values.
  --checkpoint CHECKPOINT
                        Load a model from a checkpoint.
  --device DEVICE       The device to use.
  --sequence-length SEQUENCE_LENGTH
                        The sequence length used for input content in each micro batch.
  --seed SEED           The seed to use for sampling/training.
  --compile, --no-compile
                        Will compile the model if supported by the device.

sample:
  --sample-num-sequences SAMPLE_NUM_SEQUENCES
                        The number of sequences to generate.
  --sample-max-length SAMPLE_MAX_LENGTH
                        The maximum length of the generated sequences.
  --sample-seed SAMPLE_SEED
                        The seed to use for sampling.
```
"""

import argparse
import dataclasses
import logging

import torch

from nano_gpt.model import sample

from .model_config import (
    create_model_arguments,
    create_sample_arguments,
    load_checkpoint_context,
    model_from_args,
    sample_config_from_args,
)


_LOGGER = logging.getLogger(__name__)


def create_arguments(args: argparse.ArgumentParser) -> None:
    """Register command-line arguments for the sample command."""
    create_model_arguments(args, default_values={"seed": 42, "pretrained": "gpt2"})
    create_sample_arguments(args)
    args.add_argument(
        "text",
        type=str,
        nargs="*",
        default=["Hello, I'm a language model,"],
        help="The text to use as a prompt for sampling.",
    )


def run(args: argparse.Namespace) -> int:
    """Run the sample command."""
    with load_checkpoint_context(args) as checkpoint:
        sample_config = sample_config_from_args(args, checkpoint)
        sample_config = dataclasses.replace(
            sample_config,
            text=" ".join(args.text),
        )
        _LOGGER.info("Sample config: %s", sample_config)

        model, _, _ = model_from_args(args, checkpoint)

        model.to(args.device)
        model.eval()

        # Print the joined prompt text rather than the raw argument list.
        print(sample_config.text)
        with torch.no_grad():
            samples = sample(
                model,
                model.enc,
                sample_config.text,
                num_return_sequences=sample_config.num_return_sequences,
                max_length=sample_config.max_length,
                device=args.device,
                seed=sample_config.seed,
            )
        for s in samples:
            print(">", s)

    return 0
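The grouped help output in the docstring ("model:" and "sample:" sections, plus the paired `--compile | --no-compile` flag) is produced by standard `argparse` features. The following is a minimal, self-contained sketch of that pattern; the option names mirror the real CLI, but this `build_parser` function is illustrative and is not the project's actual `create_model_arguments`/`create_sample_arguments` implementation:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Illustrative parser reproducing the grouped help layout above."""
    parser = argparse.ArgumentParser(description="Sample from a model")
    parser.add_argument(
        "text",
        type=str,
        nargs="*",
        default=["Hello, I'm a language model,"],
        help="The text to use as a prompt for sampling.",
    )

    # Named argument groups render as separate "model:" / "sample:"
    # sections in the --help output.
    model = parser.add_argument_group("model")
    model.add_argument(
        "--pretrained",
        choices=["gpt2", "gpt2-large", "gpt2-medium", "gpt2-xl"],
        help="The name of the pretrained model to use.",
    )
    # BooleanOptionalAction (Python 3.9+) generates the paired
    # --compile / --no-compile flags shown in the usage string.
    model.add_argument(
        "--compile",
        action=argparse.BooleanOptionalAction,
        help="Will compile the model if supported by the device.",
    )

    sample_group = parser.add_argument_group("sample")
    sample_group.add_argument(
        "--sample-num-sequences",
        type=int,
        default=1,
        help="The number of sequences to generate.",
    )
    return parser
```

Hyphenated option names are exposed as underscored attributes on the parsed namespace (e.g. `--sample-num-sequences` becomes `args.sample_num_sequences`), which is how `sample_config_from_args` can read them.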