nano_gpt.tool.sample

Command-line interface for sampling from a trained model.

Usage:

usage: nano-gpt sample [-h] [--pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}]
                       [--model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}] [--checkpoint CHECKPOINT] [--device DEVICE]
                       [--sequence-length SEQUENCE_LENGTH] [--seed SEED] [--compile | --no-compile]
                       [--sample-num-sequences SAMPLE_NUM_SEQUENCES] [--sample-max-length SAMPLE_MAX_LENGTH]
                       [--sample-seed SAMPLE_SEED]
                       [text ...]

Sample from a model

positional arguments:
  text                  The text to use as a prompt for sampling.

options:
  -h, --help            show this help message and exit

model:
  --pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}
                        The name of the pretrained model to use.
  --model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}
                        Use the specified model name configuration default values.
  --checkpoint CHECKPOINT
                        Load a model from a checkpoint.
  --device DEVICE       The device to use.
  --sequence-length SEQUENCE_LENGTH
                        The sequence length used for input content in each micro batch.
  --seed SEED           The seed to use for sampling/training.
  --compile, --no-compile
                        Will compile the model if supported by the device.

sample:
  --sample-num-sequences SAMPLE_NUM_SEQUENCES
                        The number of sequences to generate.
  --sample-max-length SAMPLE_MAX_LENGTH
                        The maximum length of the generated sequences.
  --sample-seed SAMPLE_SEED
                        The seed to use for sampling.
  1"""Command-line interface for sampling from a trained model.
  2
  3Usage:
  4```
  5usage: nano-gpt sample [-h] [--pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}]
  6                       [--model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}] [--checkpoint CHECKPOINT] [--device DEVICE]
  7                       [--sequence-length SEQUENCE_LENGTH] [--seed SEED] [--compile | --no-compile]
  8                       [--sample-num-sequences SAMPLE_NUM_SEQUENCES] [--sample-max-length SAMPLE_MAX_LENGTH]
  9                       [--sample-seed SAMPLE_SEED]
 10                       [text ...]
 11
 12Sample from a model
 13
 14positional arguments:
 15  text                  The text to use as a prompt for sampling.
 16
 17options:
 18  -h, --help            show this help message and exit
 19
 20model:
 21  --pretrained {gpt2,gpt2-large,gpt2-medium,gpt2-xl}
 22                        The name of the pretrained model to use.
 23  --model {gpt2,gpt2-large,gpt2-medium,gpt2-xl,gpt2-xs,gpt2-xxs}
 24                        Use the specified model name configuration default values.
 25  --checkpoint CHECKPOINT
 26                        Load a model from a checkpoint.
 27  --device DEVICE       The device to use.
 28  --sequence-length SEQUENCE_LENGTH
 29                        The sequence length used for input content in each micro batch.
 30  --seed SEED           The seed to use for sampling/training.
 31  --compile, --no-compile
 32                        Will compile the model if supported by the device.
 33
 34sample:
 35  --sample-num-sequences SAMPLE_NUM_SEQUENCES
 36                        The number of sequences to generate.
 37  --sample-max-length SAMPLE_MAX_LENGTH
 38                        The maximum length of the generated sequences.
 39  --sample-seed SAMPLE_SEED
 40                        The seed to use for sampling.
 41```
 42"""
 43
 44import argparse
 45import logging
 46import dataclasses
 47
 48import torch
 49
 50from nano_gpt.model import sample
 51
 52from .model_config import (
 53    create_model_arguments,
 54    model_from_args,
 55    create_sample_arguments,
 56    sample_config_from_args,
 57    load_checkpoint_context,
 58)
 59
 60
 61_LOGGER = logging.getLogger(__name__)
 62
 63
 64def create_arguments(args: argparse.ArgumentParser) -> None:
 65    """Get parsed passed in arguments."""
 66    create_model_arguments(args, default_values={"seed": 42, "pretrained": "gpt2"})
 67    create_sample_arguments(args)
 68    args.add_argument(
 69        "text",
 70        type=str,
 71        nargs="*",
 72        default=["Hello, I'm a language model,"],
 73        help="The text to use as a prompt for sampling.",
 74    )
 75
 76
 77def run(args: argparse.Namespace) -> int:
 78    """Run the sample command."""
 79    with load_checkpoint_context(args) as checkpoint:
 80        sample_config = sample_config_from_args(args, checkpoint)
 81        sample_config = dataclasses.replace(
 82            sample_config,
 83            text=" ".join(args.text),
 84        )
 85        _LOGGER.info(f"Sample config: {sample_config}")
 86
 87        model, _, _ = model_from_args(args, checkpoint)
 88
 89    model.to(args.device)
 90    model.eval()
 91
 92    print(args.text)
 93    with torch.no_grad():
 94        samples = sample(
 95            model,
 96            model.enc,
 97            sample_config.text,
 98            num_return_sequences=sample_config.num_return_sequences,
 99            max_length=sample_config.max_length,
100            device=args.device,
101            seed=sample_config.seed,
102        )
103    for s in samples:
104        print(">", s)
105
106    return 0
def create_arguments(args: argparse.ArgumentParser) -> None:
    """Register the ``sample`` command's arguments on a parser.

    Adds the shared model options (passing default values for ``seed`` (42)
    and ``pretrained`` (``"gpt2"``)), the sampling options, and a positional
    ``text`` prompt.

    Args:
        args: The argument parser to add arguments to; it is mutated in
            place. Nothing is parsed here.
    """
    create_model_arguments(args, default_values={"seed": 42, "pretrained": "gpt2"})
    create_sample_arguments(args)
    # nargs="*" lets the prompt be given unquoted as several words, which are
    # collected into a list; when no words are given, the default is used.
    args.add_argument(
        "text",
        type=str,
        nargs="*",
        default=["Hello, I'm a language model,"],
        help="The text to use as a prompt for sampling.",
    )

Register the sample command's arguments on the given argument parser.

def run(args: argparse.Namespace) -> int:
    """Run the sample command.

    Builds the sampling configuration and the model from the parsed
    command-line arguments (within the checkpoint context returned by
    ``load_checkpoint_context``), then generates and prints sampled
    continuations of the prompt.

    Args:
        args: Parsed command-line arguments for the ``sample`` command.

    Returns:
        The process exit code; always ``0``.
    """
    with load_checkpoint_context(args) as checkpoint:
        sample_config = sample_config_from_args(args, checkpoint)
        # The positional ``text`` argument arrives as a list of words, so join
        # it back into a single prompt string.
        sample_config = dataclasses.replace(
            sample_config,
            text=" ".join(args.text),
        )
        # Lazy %-formatting: the config is only stringified if INFO is enabled.
        _LOGGER.info("Sample config: %s", sample_config)

        # Build the model while the checkpoint context is still open.
        model, _, _ = model_from_args(args, checkpoint)

    model.to(args.device)
    model.eval()

    # Echo the prompt actually used for sampling, rather than the repr of the
    # raw list of words (e.g. "['Hello,', 'world']").
    print(sample_config.text)
    with torch.no_grad():
        samples = sample(
            model,
            model.enc,
            sample_config.text,
            num_return_sequences=sample_config.num_return_sequences,
            max_length=sample_config.max_length,
            device=args.device,
            seed=sample_config.seed,
        )
    for s in samples:
        print(">", s)

    return 0

Run the sample command.