nano_gpt.trainer

Trainer for nano-gpt.

This module provides a trainer for the nano-gpt model. It provides a main training loop that can be used to train the model on a dataset. It also provides a function for computing the loss on a dataset, and a class for managing the state of the training process.

This supports DDP for multi-GPU training. The training process is also resumable using checkpoints.

  1"""Trainer for nano-gpt.
  2
  3This module provides a trainer for the nano-gpt model. It provides a main training
  4loop that can be used to train the model on a dataset. It also provides a function
  5for computing the loss on a dataset, and a class for managing the state of the
  6training process.
  7
  8This supports DDP for multi-GPU training. The training process is also resumable
  9using checkpoints.
 10"""
 11
 12from collections.abc import Iterator, Iterable
 13import dataclasses
 14from dataclasses import dataclass
 15import logging
 16import math
 17import os
 18import pathlib
 19import time
 20from typing import Any
 21
 22import torch
 23from torch import nn
 24from torch.distributed import init_process_group
 25from torch.nn.parallel import DistributedDataParallel as DDP
 26import torch.distributed as dist
 27
 28from . import hellaswag_eval
 29from .model import sample, GPT
 30from .config import TrainConfig, EvalConfig, SampleConfig, DatasetConfig
 31from .datasets import hellaswag
 32from .checkpoint import save_checkpoint, Checkpoint
 33from .devices import get_dtype
 34from .log import LogRecord, create_log
 35
 36__all__ = [
 37    "train",
 38    "create_optimizer",
 39]
 40
 41
 42_LOGGER = logging.getLogger(__name__)
 43
 44
 45def get_lr(config: TrainConfig, step: int) -> float:
 46    """Learning rate based on the current step."""
 47    if step < config.warmup_steps:
 48        return config.max_lr * (step + 1) / config.warmup_steps
 49    if step > config.max_steps:
 50        return config.min_lr
 51    decay_ratio = (step - config.warmup_steps) / (
 52        config.max_steps - config.warmup_steps
 53    )
 54    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
 55    return config.min_lr + coeff * (config.max_lr - config.min_lr)
 56
 57
 58@dataclass(frozen=True, kw_only=True)
 59class ValStats:
 60    """Validation statistics for logging."""
 61
 62    step: int = 0
 63    loss_accum: float = 0.0
 64
 65    def log_record(self) -> LogRecord:
 66        """Log record."""
 67        return LogRecord(
 68            log_type="val",
 69            data={
 70                "step": self.step,
 71                "loss": f"{self.loss_accum:0.4f}",
 72            },
 73        )
 74
 75
 76class WorkerState:
 77    """State for multi-processing using Distributed Data Parallel."""
 78
 79    def __init__(self, device: str) -> None:
 80        """Initialize the state."""
 81        # set up DDP (distributed data parallel).
 82        # torchrun command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
 83        self.ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
 84        if self.ddp and "cuda" not in device:
 85            self.ddp = False
 86            _LOGGER.warning(
 87                "DDP requested but requested device is not cuda, disabling DDP"
 88            )
 89        if self.ddp:
 90            init_process_group(backend="nccl")
 91            self.ddp_rank = int(os.environ["RANK"])
 92            self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
 93            self.ddp_world_size = int(os.environ["WORLD_SIZE"])
 94            self.device = f"cuda:{self.ddp_local_rank}"
 95            torch.cuda.set_device(self.device)
 96        else:
 97            self.ddp_rank = 0
 98            self.ddp_local_rank = 0
 99            self.ddp_world_size = 1
100            self.device = device
101
102    @property
103    def is_cuda(self) -> bool:
104        """Check if the device is CUDA."""
105        return "cuda" in self.device
106
107    @property
108    def dtype(self) -> torch.dtype:
109        """Get the data type."""
110        return get_dtype(self.device)
111
112    @property
113    def is_primary(self) -> bool:
114        """The primary process will do logging, checkpointing, etc."""
115        return self.ddp_rank == 0
116
117    def __str__(self) -> str:
118        """String representation."""
119        return f"WorkerState(ddp={self.ddp}, ddp_rank={self.ddp_rank}, ddp_local_rank={self.ddp_local_rank}, ddp_world_size={self.ddp_world_size}, device={self.device})"
120
121
def compute_loss(
    model: nn.Module,
    worker_state: WorkerState,
    log_label: str,
    ds: Iterator[tuple[torch.Tensor, torch.Tensor]],
    steps: int,
    backward: bool,
) -> float:
    """Run ``steps`` micro-batches through the model and return the mean loss.

    When ``backward`` is True, gradients are accumulated across all micro-steps
    (with DDP gradient sync deferred to the last one). For validation, call
    with the model in eval mode and ``backward=False`` under ``torch.no_grad``.
    This consumes items from ``ds``, so the iterator must be positioned
    correctly before calling.  ``log_label`` is currently unused but kept for
    interface stability.

    Raises:
        ValueError: if ``steps`` is zero.
    """
    if not steps:
        raise ValueError("steps must be greater than 0")
    # BUGFIX: torch.autocast wants the bare device type ("cuda", "cpu"), not a
    # device string with an ordinal such as "cuda:0" (which DDP sets).
    device_type = worker_state.device.split(":")[0]
    loss_accum = 0.0
    for step in range(steps):
        x, y = next(ds)
        x, y = x.to(worker_state.device), y.to(worker_state.device)
        if worker_state.ddp:
            # Only sync gradients across ranks on the final micro-step.
            model.require_backward_grad_sync = step == (steps - 1)  # type: ignore[assignment]
        with torch.autocast(device_type=device_type, dtype=worker_state.dtype):
            _, loss = model(x, y)
        # Scale so the sum over micro-steps equals the mean loss.
        loss = loss / steps
        loss_accum += loss.detach().item()
        if backward:
            loss.backward()
    return loss_accum
151
152
class TrainStats:
    """Per-step training statistics for logging.

    NOTE: this was previously decorated with ``@dataclass``, which was a no-op
    for ``__init__`` (hand-written) but silently generated a field-less
    ``__eq__``/``__repr__`` under which *all* instances compared equal; the
    decorator is dropped so instances have normal identity semantics.
    """

    def __init__(self, config: TrainConfig) -> None:
        """Initialize counters from the training config (resumes at config.step)."""
        self.step = config.step  # current training step
        self.t0: float = 0.0  # wall-clock time at the start of the step
        self.config = config
        self.stats: dict[str, Any] = {}  # latest formatted metrics

    def start_step(self) -> None:
        """Mark the beginning of a training step."""
        self.t0 = time.time()

    def end_step(self, loss: float, norm: float) -> None:
        """Record metrics for the step that just finished and advance the counter."""
        t1 = time.time()
        dt = (t1 - self.t0) * 1000  # step duration in milliseconds
        tok_per_sec = self.config.total_batch_size / (t1 - self.t0)
        lr = get_lr(self.config, self.step)
        self.stats.update(
            {
                "step": self.step,
                "loss": f"{loss:0.4f}",
                "norm": f"{norm:0.4f}",
                "dt": f"{dt:0.2f}ms",
                "tok/sec": f"{tok_per_sec:0.2f}",
                "lr": f"{lr:0.6f}",
            }
        )
        self.step += 1

    def log_record(self) -> LogRecord:
        """Return the latest stats as a ``train`` log record."""
        return LogRecord(
            log_type="train",
            data=self.stats,
        )
192
193
def create_optimizer(
    raw_model: GPT,
    config: TrainConfig,
    checkpoint: Checkpoint | None,
    is_cuda: bool,
) -> torch.optim.Optimizer:
    """Build the model optimizer, optionally restoring state from a checkpoint."""
    optimizer = raw_model.configure_optimizers(
        weight_decay=0.1,
        learning_rate=get_lr(config, 0),
        use_fused=is_cuda,
    )
    # Resume optimizer moments/step counts when the checkpoint carries them.
    saved_state = None if checkpoint is None else checkpoint.optimizer_state_dict
    if saved_state is not None:
        _LOGGER.info("Loading optimizer state from checkpoint")
        optimizer.load_state_dict(saved_state)
    _LOGGER.debug("Optimizer: %s", optimizer.state_dict())
    return optimizer
211
212
def train(
    raw_model: GPT,
    optimizer: torch.optim.Optimizer,
    worker_state: WorkerState,
    config: TrainConfig,
    train_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]],
    eval_config: EvalConfig | None = None,
    dataset_config: DatasetConfig | None = None,
    val_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]] | None = None,
    hellaswag_loader: Iterable[hellaswag.Sample] | None = None,
    sample_config: SampleConfig | None = None,
) -> None:
    """Train the model.

    This is the main training loop. It trains the model for the number of steps
    specified in the config, periodically evaluating on the validation set,
    scoring HellaSwag, emitting text samples, and saving checkpoints.

    In DDP mode, per-rank loss values are averaged across ranks before being
    logged/checkpointed, and HellaSwag counts are summed across ranks.
    """
    config.log_info(worker_state.ddp_world_size)
    # Only the primary rank logs to file/stdout; other ranks log nowhere.
    if worker_state.is_primary:
        log = create_log(
            pathlib.Path(config.log_file) if config.log_file else None, log_stdout=True
        )
    else:
        log = create_log(None, False)

    model: nn.Module = raw_model
    tokenizer = raw_model.enc
    model.to(worker_state.device)
    if worker_state.ddp:
        model = DDP(model, device_ids=[worker_state.ddp_local_rank])

    train_ds = iter(train_data_loader)
    stats = TrainStats(config)
    for step in range(config.step, config.max_steps):
        last_step = step == config.max_steps - 1
        stats.start_step()

        val_stats: ValStats | None = None
        eval_step = step % config.eval_steps == 0
        checkpoint_step = step % config.checkpoint_steps == 0

        # -- Validation loss -------------------------------------------------
        if (
            (eval_step or last_step)
            and val_data_loader is not None
            and eval_config is not None
            and eval_config.validation_steps
        ):
            model.eval()
            val_ds = iter(val_data_loader)
            with torch.no_grad():
                val_loss_accum = compute_loss(
                    model,
                    worker_state,
                    "val",
                    val_ds,
                    eval_config.validation_steps,
                    backward=False,
                )
            if worker_state.ddp:
                val_loss_tensor = torch.tensor(
                    val_loss_accum, device=worker_state.device
                )
                dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
                # BUGFIX: the cross-rank average was previously computed and
                # then discarded; use it so all ranks report the same value.
                val_loss_accum = val_loss_tensor.item()
            val_stats = ValStats(step=step, loss_accum=val_loss_accum)
            if worker_state.is_primary:
                log.log(val_stats.log_record())

        # -- Checkpointing (primary rank only) --------------------------------
        if (
            step != 0
            and (step != config.step)  # don't save the initial checkpoint
            and (checkpoint_step or last_step)
            and worker_state.is_primary
            and config.checkpoint_dir is not None
        ):
            checkpoint_path = (
                pathlib.Path(config.checkpoint_dir) / f"checkpoint_{step:06d}.bin"
            )
            checkpoint: Checkpoint = Checkpoint(
                model_state_dict=raw_model.state_dict(),
                config=raw_model.config,
                step=step,
                val_loss_accum=(
                    val_stats.loss_accum if val_stats is not None else None
                ),
                optimizer_state_dict=optimizer.state_dict(),
                train_config=config,
                dataset_config=dataset_config,
                eval_config=eval_config,
                sample_config=sample_config,
            )
            save_checkpoint(checkpoint, checkpoint_path)

        # -- HellaSwag evaluation ---------------------------------------------
        if (
            step != 0
            and (eval_step or last_step)
            and hellaswag_loader is not None
            and eval_config is not None
            and eval_config.hellaswag_samples
        ):
            model.eval()
            with torch.no_grad():
                hellaswag_result = hellaswag_eval.evaluate(
                    model,
                    tokenizer,
                    hellaswag_loader,
                    worker_state.device,
                    eval_config.hellaswag_samples,
                )
            if worker_state.ddp:
                # Each rank scored a different shard; sum the counts.
                num_total = torch.tensor(
                    hellaswag_result.total,
                    dtype=torch.long,
                    device=worker_state.device,
                )
                num_correct_norm = torch.tensor(
                    hellaswag_result.correct,
                    dtype=torch.long,
                    device=worker_state.device,
                )
                dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
                dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
                hellaswag_result = dataclasses.replace(
                    hellaswag_result,
                    total=int(num_total.item()),
                    correct=int(num_correct_norm.item()),
                )
            if worker_state.is_primary:
                log.log(hellaswag_result.log_record())

        # -- Text sampling ----------------------------------------------------
        if (
            step > 0
            and eval_step
            and sample_config is not None
            and sample_config.num_return_sequences
        ):
            model.eval()
            with torch.no_grad():
                samples = sample(
                    model,
                    tokenizer,
                    sample_config.text,
                    num_return_sequences=sample_config.num_return_sequences,
                    max_length=sample_config.max_length,
                    device=worker_state.device,
                    # Offset the seed per rank so ranks print distinct samples.
                    seed=sample_config.seed + worker_state.ddp_rank,
                )
            for i, s in enumerate(samples):
                print(f"rank {worker_state.ddp_rank} sample {i}: {s}")

        # -- Optimization step -------------------------------------------------
        model.train()
        optimizer.zero_grad()
        loss_accum = compute_loss(
            model,
            worker_state,
            "train",
            train_ds,
            config.grad_accum_steps(worker_state.ddp_world_size),
            backward=True,
        )
        if worker_state.ddp:
            loss_tensor = torch.tensor(loss_accum, device=worker_state.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            # BUGFIX: log the cross-rank average instead of discarding it.
            loss_accum = loss_tensor.item()

        # Prevent the model from getting large shocks of gradient magnitude
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update the learning rate based on the step
        lr = get_lr(config, step)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        optimizer.step()
        if worker_state.is_cuda:
            # Wait for GPU work so step timing in stats is accurate.
            torch.cuda.synchronize()

        stats.end_step(loss_accum, norm.item())
        if worker_state.is_primary:
            log.log(stats.log_record())
def train( raw_model: nano_gpt.model.GPT, optimizer: torch.optim.optimizer.Optimizer, worker_state: nano_gpt.trainer.WorkerState, config: nano_gpt.config.TrainConfig, train_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]], eval_config: nano_gpt.config.EvalConfig | None = None, dataset_config: nano_gpt.config.DatasetConfig | None = None, val_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]] | None = None, hellaswag_loader: Iterable[nano_gpt.datasets.hellaswag.Sample] | None = None, sample_config: nano_gpt.config.SampleConfig | None = None) -> None:
def train(
    raw_model: GPT,
    optimizer: torch.optim.Optimizer,
    worker_state: WorkerState,
    config: TrainConfig,
    train_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]],
    eval_config: EvalConfig | None = None,
    dataset_config: DatasetConfig | None = None,
    val_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]] | None = None,
    hellaswag_loader: Iterable[hellaswag.Sample] | None = None,
    sample_config: SampleConfig | None = None,
) -> None:
    """Train the model.

    This is the main training loop. It trains the model for the number of steps
    specified in the config, periodically evaluating on the validation set,
    scoring HellaSwag, emitting text samples, and saving checkpoints.

    In DDP mode, per-rank loss values are averaged across ranks before being
    logged/checkpointed, and HellaSwag counts are summed across ranks.
    """
    config.log_info(worker_state.ddp_world_size)
    # Only the primary rank logs to file/stdout; other ranks log nowhere.
    if worker_state.is_primary:
        log = create_log(
            pathlib.Path(config.log_file) if config.log_file else None, log_stdout=True
        )
    else:
        log = create_log(None, False)

    model: nn.Module = raw_model
    tokenizer = raw_model.enc
    model.to(worker_state.device)
    if worker_state.ddp:
        model = DDP(model, device_ids=[worker_state.ddp_local_rank])

    train_ds = iter(train_data_loader)
    stats = TrainStats(config)
    for step in range(config.step, config.max_steps):
        last_step = step == config.max_steps - 1
        stats.start_step()

        val_stats: ValStats | None = None
        eval_step = step % config.eval_steps == 0
        checkpoint_step = step % config.checkpoint_steps == 0

        # -- Validation loss -------------------------------------------------
        if (
            (eval_step or last_step)
            and val_data_loader is not None
            and eval_config is not None
            and eval_config.validation_steps
        ):
            model.eval()
            val_ds = iter(val_data_loader)
            with torch.no_grad():
                val_loss_accum = compute_loss(
                    model,
                    worker_state,
                    "val",
                    val_ds,
                    eval_config.validation_steps,
                    backward=False,
                )
            if worker_state.ddp:
                val_loss_tensor = torch.tensor(
                    val_loss_accum, device=worker_state.device
                )
                dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
                # BUGFIX: the cross-rank average was previously computed and
                # then discarded; use it so all ranks report the same value.
                val_loss_accum = val_loss_tensor.item()
            val_stats = ValStats(step=step, loss_accum=val_loss_accum)
            if worker_state.is_primary:
                log.log(val_stats.log_record())

        # -- Checkpointing (primary rank only) --------------------------------
        if (
            step != 0
            and (step != config.step)  # don't save the initial checkpoint
            and (checkpoint_step or last_step)
            and worker_state.is_primary
            and config.checkpoint_dir is not None
        ):
            checkpoint_path = (
                pathlib.Path(config.checkpoint_dir) / f"checkpoint_{step:06d}.bin"
            )
            checkpoint: Checkpoint = Checkpoint(
                model_state_dict=raw_model.state_dict(),
                config=raw_model.config,
                step=step,
                val_loss_accum=(
                    val_stats.loss_accum if val_stats is not None else None
                ),
                optimizer_state_dict=optimizer.state_dict(),
                train_config=config,
                dataset_config=dataset_config,
                eval_config=eval_config,
                sample_config=sample_config,
            )
            save_checkpoint(checkpoint, checkpoint_path)

        # -- HellaSwag evaluation ---------------------------------------------
        if (
            step != 0
            and (eval_step or last_step)
            and hellaswag_loader is not None
            and eval_config is not None
            and eval_config.hellaswag_samples
        ):
            model.eval()
            with torch.no_grad():
                hellaswag_result = hellaswag_eval.evaluate(
                    model,
                    tokenizer,
                    hellaswag_loader,
                    worker_state.device,
                    eval_config.hellaswag_samples,
                )
            if worker_state.ddp:
                # Each rank scored a different shard; sum the counts.
                num_total = torch.tensor(
                    hellaswag_result.total,
                    dtype=torch.long,
                    device=worker_state.device,
                )
                num_correct_norm = torch.tensor(
                    hellaswag_result.correct,
                    dtype=torch.long,
                    device=worker_state.device,
                )
                dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
                dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
                hellaswag_result = dataclasses.replace(
                    hellaswag_result,
                    total=int(num_total.item()),
                    correct=int(num_correct_norm.item()),
                )
            if worker_state.is_primary:
                log.log(hellaswag_result.log_record())

        # -- Text sampling ----------------------------------------------------
        if (
            step > 0
            and eval_step
            and sample_config is not None
            and sample_config.num_return_sequences
        ):
            model.eval()
            with torch.no_grad():
                samples = sample(
                    model,
                    tokenizer,
                    sample_config.text,
                    num_return_sequences=sample_config.num_return_sequences,
                    max_length=sample_config.max_length,
                    device=worker_state.device,
                    # Offset the seed per rank so ranks print distinct samples.
                    seed=sample_config.seed + worker_state.ddp_rank,
                )
            for i, s in enumerate(samples):
                print(f"rank {worker_state.ddp_rank} sample {i}: {s}")

        # -- Optimization step -------------------------------------------------
        model.train()
        optimizer.zero_grad()
        loss_accum = compute_loss(
            model,
            worker_state,
            "train",
            train_ds,
            config.grad_accum_steps(worker_state.ddp_world_size),
            backward=True,
        )
        if worker_state.ddp:
            loss_tensor = torch.tensor(loss_accum, device=worker_state.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            # BUGFIX: log the cross-rank average instead of discarding it.
            loss_accum = loss_tensor.item()

        # Prevent the model from getting large shocks of gradient magnitude
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update the learning rate based on the step
        lr = get_lr(config, step)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        optimizer.step()
        if worker_state.is_cuda:
            # Wait for GPU work so step timing in stats is accurate.
            torch.cuda.synchronize()

        stats.end_step(loss_accum, norm.item())
        if worker_state.is_primary:
            log.log(stats.log_record())

Train the model.

This is the main training loop. It will train the model for the number of steps specified in the config. It will also evaluate the model on the validation set and save checkpoints.

def create_optimizer( raw_model: nano_gpt.model.GPT, config: nano_gpt.config.TrainConfig, checkpoint: nano_gpt.checkpoint.Checkpoint | None, is_cuda: bool) -> torch.optim.optimizer.Optimizer:
def create_optimizer(
    raw_model: GPT,
    config: TrainConfig,
    checkpoint: Checkpoint | None,
    is_cuda: bool,
) -> torch.optim.Optimizer:
    """Build the model optimizer, optionally restoring state from a checkpoint."""
    optimizer = raw_model.configure_optimizers(
        weight_decay=0.1,
        learning_rate=get_lr(config, 0),
        use_fused=is_cuda,
    )
    # Resume optimizer moments/step counts when the checkpoint carries them.
    saved_state = None if checkpoint is None else checkpoint.optimizer_state_dict
    if saved_state is not None:
        _LOGGER.info("Loading optimizer state from checkpoint")
        optimizer.load_state_dict(saved_state)
    _LOGGER.debug("Optimizer: %s", optimizer.state_dict())
    return optimizer

Create the optimizer with the option to resume from a checkpoint.