nano_gpt.tool.prepare_dataset

Command-line interface for preparing datasets for training runs.

Usage:

usage: nano-gpt prepare_dataset [-h] --dataset {finewebedu,tinyshakespeare} [--splits SPLITS] [--tokens-per-shard TOKENS_PER_SHARD]
                                [--dataset-dir DATASET_DIR] [--num-procs NUM_PROCS]

Prepare a dataset for training

options:
  -h, --help            show this help message and exit
  --dataset {finewebedu,tinyshakespeare}
                        Use the specified dataset.
  --splits SPLITS       Comma-separated list of splits to prepare.
  --tokens-per-shard TOKENS_PER_SHARD
                        Number of tokens per shard.
  --dataset-dir DATASET_DIR
                        Directory to store the dataset.
  --num-procs NUM_PROCS
                        Number of processes to use for preprocessing.
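
For example, to download and tokenize the Tiny Shakespeare corpus with the default splits, shard size, and dataset directory:

    nano-gpt prepare_dataset --dataset tinyshakespeare

or to prepare FineWeb-Edu into an explicit directory with more worker processes (the path and process count here are illustrative):

    nano-gpt prepare_dataset --dataset finewebedu --dataset-dir ./datasets --num-procs 8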
"""Command-line interface for preparing datasets for training runs.

Usage:
```
usage: nano-gpt prepare_dataset [-h] --dataset {finewebedu,tinyshakespeare} [--splits SPLITS] [--tokens-per-shard TOKENS_PER_SHARD]
                                [--dataset-dir DATASET_DIR] [--num-procs NUM_PROCS]

Prepare a dataset for training

options:
  -h, --help            show this help message and exit
  --dataset {finewebedu,tinyshakespeare}
                        Use the specified dataset.
  --splits SPLITS       Comma-separated list of splits to prepare.
  --tokens-per-shard TOKENS_PER_SHARD
                        Number of tokens per shard.
  --dataset-dir DATASET_DIR
                        Directory to store the dataset.
  --num-procs NUM_PROCS
                        Number of processes to use for preprocessing.
```
"""

import argparse
import logging
import os
import pathlib

from nano_gpt.datasets import TRAIN_DATASETS
from nano_gpt.datasets.data_loader import preprocess_corpus, SPLITS
from nano_gpt.tokenizer import get_document_tokenizer

from .model_config import DATASET_DIR

_LOGGER = logging.getLogger(__name__)


def create_arguments(args: argparse.ArgumentParser) -> None:
    """Add the prepare_dataset command-line arguments to the parser."""
    args.add_argument(
        "--dataset",
        type=str,
        help="Use the specified dataset.",
        choices=sorted(TRAIN_DATASETS.keys()),
        required=True,
    )
    args.add_argument(
        "--splits",
        type=str,
        help="Comma-separated list of splits to prepare.",
        default=",".join(SPLITS),
    )
    args.add_argument(
        "--tokens-per-shard",
        type=int,
        help="Number of tokens per shard.",
        default=100_000_000,  # 100 million tokens/shard
    )
    args.add_argument(
        "--dataset-dir",
        type=str,
        help="Directory to store the dataset.",
        default=DATASET_DIR,
    )
    # Default to half the available cores, but never fewer than one process.
    default_cpu_count = 1
    if (cnt := os.cpu_count()) is not None:
        default_cpu_count = max(default_cpu_count, cnt // 2)
    args.add_argument(
        "--num-procs",
        type=int,
        help="Number of processes to use for preprocessing.",
        default=default_cpu_count,
    )


def run(args: argparse.Namespace) -> int:
    """Run the prepare_dataset command."""

    dataset_dir = pathlib.Path(args.dataset_dir)
    dataset_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = get_document_tokenizer()

    dataset = TRAIN_DATASETS[args.dataset]
    _LOGGER.info("Loading dataset %s", args.dataset)

    splits = args.splits.split(",")
    for split in splits:
        if split not in SPLITS:
            raise ValueError(f"Invalid split {split}, must be one of {SPLITS}")
        _LOGGER.info("Loading dataset %s for split %s", args.dataset, split)
        output_path = dataset_dir / f"{args.dataset}_{split}.npy"
        ds = dataset.load_fn(split=split, streaming=False)
        preprocess_corpus(
            ds,
            tokenizer,
            output_path,
            num_procs=max(args.num_procs, 1),
            tokens_per_shard=args.tokens_per_shard,  # honor the --tokens-per-shard flag
        )

    return 0
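
For orientation, the two entry points compose as follows: create_arguments populates a parser (or subparser), and run consumes the resulting namespace. Below is a minimal sketch, assuming the nano-gpt dispatcher registers subcommands with argparse; the registration code is illustrative, and only create_arguments and run come from this module.

import argparse

from nano_gpt.tool import prepare_dataset

parser = argparse.ArgumentParser(prog="nano-gpt")
subparsers = parser.add_subparsers(dest="command", required=True)
sub = subparsers.add_parser("prepare_dataset", help="Prepare a dataset for training")
prepare_dataset.create_arguments(sub)

# Parse an example command line and dispatch to the command.
args = parser.parse_args(["prepare_dataset", "--dataset", "tinyshakespeare"])
raise SystemExit(prepare_dataset.run(args))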
def create_arguments(args: argparse.ArgumentParser) -> None:

Add the prepare_dataset command-line arguments to the parser.
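
A quick way to inspect the defaults this registers (a sketch; the printed values depend on SPLITS, DATASET_DIR, and your CPU count):

import argparse

from nano_gpt.tool import prepare_dataset

parser = argparse.ArgumentParser()
prepare_dataset.create_arguments(parser)
args = parser.parse_args(["--dataset", "tinyshakespeare"])  # --dataset is required
print(args.splits, args.tokens_per_shard, args.dataset_dir, args.num_procs)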

def run(args: argparse.Namespace) -> int:

Run the prepare_dataset command.
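
run can also be invoked programmatically with any namespace carrying the attributes that create_arguments defines. A minimal sketch (the split name and output directory are illustrative; valid split names come from SPLITS in nano_gpt.datasets.data_loader):

import argparse

from nano_gpt.tool import prepare_dataset

# Hypothetical direct invocation; attribute names mirror the CLI flags.
args = argparse.Namespace(
    dataset="tinyshakespeare",
    splits="train",                 # must name splits present in SPLITS
    tokens_per_shard=100_000_000,
    dataset_dir="./datasets",       # illustrative output directory
    num_procs=4,
)
prepare_dataset.run(args)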