nano_gpt.trainer
Trainer for nano-gpt.
This module provides a trainer for the nano-gpt model. It provides a main training loop that can be used to train the model on a dataset. It also provides a function for computing the loss on a dataset, and a class for managing the state of the training process.
This supports DDP for multi-GPU training. The training process is also resumable using checkpoints.
"""Trainer for nano-gpt.

This module provides a trainer for the nano-gpt model. It provides a main training
loop that can be used to train the model on a dataset. It also provides a function
for computing the loss on a dataset, and a class for managing the state of the
training process.

This supports DDP for multi-GPU training. The training process is also resumable
using checkpoints.
"""

from collections.abc import Iterator, Iterable
import dataclasses
from dataclasses import dataclass
import logging
import math
import os
import pathlib
import time
from typing import Any

import torch
from torch import nn
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

from . import hellaswag_eval
from .model import sample, GPT
from .config import TrainConfig, EvalConfig, SampleConfig, DatasetConfig
from .datasets import hellaswag
from .checkpoint import save_checkpoint, Checkpoint
from .devices import get_dtype
from .log import LogRecord, create_log

__all__ = [
    "train",
    "create_optimizer",
]


_LOGGER = logging.getLogger(__name__)


def get_lr(config: TrainConfig, step: int) -> float:
    """Return the learning rate for the given step.

    Schedule: linear warmup to ``max_lr`` over ``warmup_steps``, then cosine
    decay down to ``min_lr`` at ``max_steps``, constant ``min_lr`` afterwards.
    """
    if step < config.warmup_steps:
        # Linear warmup; the +1 avoids a zero learning rate at step 0.
        return config.max_lr * (step + 1) / config.warmup_steps
    if step > config.max_steps:
        return config.min_lr
    decay_ratio = (step - config.warmup_steps) / (
        config.max_steps - config.warmup_steps
    )
    # Cosine coefficient decays from 1.0 to 0.0 over the decay window.
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return config.min_lr + coeff * (config.max_lr - config.min_lr)


@dataclass(frozen=True, kw_only=True)
class ValStats:
    """Validation statistics for logging."""

    # Training step at which validation was run.
    step: int = 0
    # Accumulated (averaged) validation loss.
    loss_accum: float = 0.0

    def log_record(self) -> LogRecord:
        """Log record."""
        return LogRecord(
            log_type="val",
            data={
                "step": self.step,
                "loss": f"{self.loss_accum:0.4f}",
            },
        )


class WorkerState:
    """State for multi-processing using Distributed Data Parallel."""

    def __init__(self, device: str) -> None:
        """Initialize the state.

        Args:
            device: Requested device string (e.g. "cpu", "cuda").
        """
        # Set up DDP (distributed data parallel).
        # torchrun sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE.
        self.ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
        if self.ddp and "cuda" not in device:
            # DDP here uses the nccl backend, which requires CUDA devices.
            self.ddp = False
            _LOGGER.warning(
                "DDP requested but requested device is not cuda, disabling DDP"
            )
        if self.ddp:
            init_process_group(backend="nccl")
            self.ddp_rank = int(os.environ["RANK"])
            self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
            self.ddp_world_size = int(os.environ["WORLD_SIZE"])
            self.device = f"cuda:{self.ddp_local_rank}"
            torch.cuda.set_device(self.device)
        else:
            self.ddp_rank = 0
            self.ddp_local_rank = 0
            self.ddp_world_size = 1
            self.device = device

    @property
    def is_cuda(self) -> bool:
        """Check if the device is CUDA."""
        return "cuda" in self.device

    @property
    def dtype(self) -> torch.dtype:
        """Get the data type."""
        return get_dtype(self.device)

    @property
    def autocast_device_type(self) -> str:
        """Device *type* ("cuda", "cpu", ...) as required by torch.autocast.

        ``torch.autocast`` rejects full device strings such as "cuda:0", which
        is what ``device`` holds under DDP, so strip any ":index" suffix.
        """
        return self.device.split(":")[0]

    @property
    def is_primary(self) -> bool:
        """The primary process will do logging, checkpointing, etc."""
        return self.ddp_rank == 0

    def __str__(self) -> str:
        """String representation."""
        return f"WorkerState(ddp={self.ddp}, ddp_rank={self.ddp_rank}, ddp_local_rank={self.ddp_local_rank}, ddp_world_size={self.ddp_world_size}, device={self.device})"


def compute_loss(
    model: nn.Module,
    worker_state: WorkerState,
    log_label: str,
    ds: Iterator[tuple[torch.Tensor, torch.Tensor]],
    steps: int,
    backward: bool,
) -> float:
    """Compute the loss over ``steps`` batches drawn from ``ds``.

    For validation it is expected that the model is called in eval mode (and
    under ``torch.no_grad()``); for training pass ``backward=True`` to
    accumulate gradients. This consumes items from the dataset, so it needs
    to be in the correct state before calling.

    Args:
        model: The model (possibly DDP-wrapped) to evaluate.
        worker_state: Device/DDP state for this worker.
        log_label: Label for the loss computation (e.g. "train", "val").
        ds: Iterator yielding (input, target) tensor pairs.
        steps: Number of micro-steps to accumulate over; must be > 0.
        backward: Whether to run backward() on each micro-step loss.

    Returns:
        The loss averaged over ``steps`` micro-batches, as a Python float.

    Raises:
        ValueError: If ``steps`` is zero/falsy.
    """
    if not steps:
        raise ValueError("steps must be greater than 0")
    loss_accum = 0.0
    for step in range(steps):
        x, y = next(ds)
        x, y = x.to(worker_state.device), y.to(worker_state.device)
        if worker_state.ddp:
            # Only synchronize gradients across ranks on the final
            # micro-step of the accumulation window.
            model.require_backward_grad_sync = step == (steps - 1)  # type: ignore[assignment]
        # BUGFIX: autocast requires the device *type* ("cuda"), not a full
        # device string ("cuda:0"), which worker_state.device holds under DDP.
        with torch.autocast(
            device_type=worker_state.autocast_device_type, dtype=worker_state.dtype
        ):
            _, loss = model(x, y)
        # Scale so the sum over micro-steps equals the mean batch loss.
        loss = loss / steps
        loss_accum += loss.detach().item()
        if backward:
            loss.backward()
    return loss_accum


class TrainStats:
    """Training statistics for logging.

    Note: not a dataclass — it has a hand-written __init__ and mutable
    per-step state (the previous @dataclass decorator was a no-op).
    """

    def __init__(self, config: TrainConfig) -> None:
        """Initialize the training statistics."""
        self.step = config.step
        self.t0: float = 0.0
        self.config = config
        self.stats: dict[str, Any] = {}

    def start_step(self) -> None:
        """Record the wall-clock start time of the current step."""
        self.t0 = time.time()

    def end_step(self, loss: float, norm: float) -> None:
        """Record timing/throughput stats for the step and advance the counter."""
        t1 = time.time()
        dt = (t1 - self.t0) * 1000
        tok_per_sec = self.config.total_batch_size / (t1 - self.t0)
        lr = get_lr(self.config, self.step)
        self.stats.update(
            {
                "step": self.step,
                "loss": f"{loss:0.4f}",
                "norm": f"{norm:0.4f}",
                "dt": f"{dt:0.2f}ms",
                "tok/sec": f"{tok_per_sec:0.2f}",
                "lr": f"{lr:0.6f}",
            }
        )
        self.step += 1

    def log_record(self) -> LogRecord:
        """Log record."""
        return LogRecord(
            log_type="train",
            data=self.stats,
        )


def create_optimizer(
    raw_model: GPT,
    config: TrainConfig,
    checkpoint: Checkpoint | None,
    is_cuda: bool,
) -> torch.optim.Optimizer:
    """Create the optimizer with the option to resume from a checkpoint.

    Args:
        raw_model: The unwrapped GPT model to optimize.
        config: Training configuration (used for the initial learning rate).
        checkpoint: Optional checkpoint to restore optimizer state from.
        is_cuda: Whether to use the fused optimizer implementation.
    """
    optimizer = raw_model.configure_optimizers(
        weight_decay=0.1,
        learning_rate=get_lr(config, 0),
        use_fused=is_cuda,
    )
    if checkpoint is not None and checkpoint.optimizer_state_dict is not None:
        _LOGGER.info("Loading optimizer state from checkpoint")
        optimizer.load_state_dict(checkpoint.optimizer_state_dict)
    _LOGGER.debug("Optimizer: %s", optimizer.state_dict())
    return optimizer


def train(
    raw_model: GPT,
    optimizer: torch.optim.Optimizer,
    worker_state: WorkerState,
    config: TrainConfig,
    train_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]],
    eval_config: EvalConfig | None = None,
    dataset_config: DatasetConfig | None = None,
    val_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]] | None = None,
    hellaswag_loader: Iterable[hellaswag.Sample] | None = None,
    sample_config: SampleConfig | None = None,
) -> None:
    """Train the model.

    This is the main training loop. It will train the model for the number of steps
    specified in the config. It will also evaluate the model on the validation set
    and save checkpoints.
    """
    config.log_info(worker_state.ddp_world_size)
    # Only the primary rank logs to file/stdout; other ranks get a no-op log.
    if worker_state.is_primary:
        log = create_log(
            pathlib.Path(config.log_file) if config.log_file else None, log_stdout=True
        )
    else:
        log = create_log(None, False)

    model: nn.Module = raw_model
    tokenizer = raw_model.enc
    model.to(worker_state.device)
    if worker_state.ddp:
        model = DDP(model, device_ids=[worker_state.ddp_local_rank])

    train_ds = iter(train_data_loader)
    stats = TrainStats(config)
    for step in range(config.step, config.max_steps):
        last_step = step == config.max_steps - 1
        stats.start_step()

        val_stats: ValStats | None = None
        eval_step = step % config.eval_steps == 0
        checkpoint_step = step % config.checkpoint_steps == 0
        if (
            (eval_step or last_step)
            and val_data_loader is not None
            and eval_config is not None
            and eval_config.validation_steps
        ):
            model.eval()
            val_ds = iter(val_data_loader)
            with torch.no_grad():
                val_loss_accum = compute_loss(
                    model,
                    worker_state,
                    "val",
                    val_ds,
                    eval_config.validation_steps,
                    backward=False,
                )
            if worker_state.ddp:
                # BUGFIX: use the rank-averaged loss. The original reduced
                # into a tensor and then discarded it, logging/checkpointing
                # only this rank's local value.
                val_loss_tensor = torch.tensor(
                    val_loss_accum, device=worker_state.device
                )
                dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
                val_loss_accum = val_loss_tensor.item()
            val_stats = ValStats(step=step, loss_accum=val_loss_accum)
            if worker_state.is_primary:
                log.log(val_stats.log_record())

        if (
            step != 0
            and (step != config.step)  # don't save the initial checkpoint
            and (checkpoint_step or last_step)
            and worker_state.is_primary
            and config.checkpoint_dir is not None
        ):
            checkpoint_path = (
                pathlib.Path(config.checkpoint_dir) / f"checkpoint_{step:06d}.bin"
            )
            checkpoint: Checkpoint = Checkpoint(
                model_state_dict=raw_model.state_dict(),
                config=raw_model.config,
                step=step,
                val_loss_accum=(
                    val_stats.loss_accum if val_stats is not None else None
                ),
                optimizer_state_dict=optimizer.state_dict(),
                train_config=config,
                dataset_config=dataset_config,
                eval_config=eval_config,
                sample_config=sample_config,
            )
            save_checkpoint(checkpoint, checkpoint_path)
        if (
            step != 0
            and (eval_step or last_step)
            and hellaswag_loader is not None
            and eval_config is not None
            and eval_config.hellaswag_samples
        ):
            model.eval()
            with torch.no_grad():
                hellaswag_result = hellaswag_eval.evaluate(
                    model,
                    tokenizer,
                    hellaswag_loader,
                    worker_state.device,
                    eval_config.hellaswag_samples,
                )
            if worker_state.ddp:
                # Sum raw counts across ranks, then rebuild the result so the
                # accuracy reflects the whole distributed evaluation.
                num_total = torch.tensor(
                    hellaswag_result.total,
                    dtype=torch.long,
                    device=worker_state.device,
                )
                num_correct_norm = torch.tensor(
                    hellaswag_result.correct,
                    dtype=torch.long,
                    device=worker_state.device,
                )
                dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
                dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
                hellaswag_result = dataclasses.replace(
                    hellaswag_result,
                    total=int(num_total.item()),
                    correct=int(num_correct_norm.item()),
                )
            if worker_state.is_primary:
                log.log(hellaswag_result.log_record())
        if (
            step > 0
            and eval_step
            and sample_config is not None
            and sample_config.num_return_sequences
        ):
            model.eval()
            with torch.no_grad():
                # Offset the seed per rank so each rank prints distinct samples.
                samples = sample(
                    model,
                    tokenizer,
                    sample_config.text,
                    num_return_sequences=sample_config.num_return_sequences,
                    max_length=sample_config.max_length,
                    device=worker_state.device,
                    seed=sample_config.seed + worker_state.ddp_rank,
                )
            for i, s in enumerate(samples):
                print(f"rank {worker_state.ddp_rank} sample {i}: {s}")

        model.train()
        optimizer.zero_grad()
        loss_accum = compute_loss(
            model,
            worker_state,
            "train",
            train_ds,
            config.grad_accum_steps(worker_state.ddp_world_size),
            backward=True,
        )
        if worker_state.ddp:
            # BUGFIX: report the rank-averaged loss (gradients are already
            # synchronized by DDP on the last micro-step). The original
            # discarded the reduced tensor and logged the local loss.
            loss_tensor = torch.tensor(loss_accum, device=worker_state.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            loss_accum = loss_tensor.item()

        # Prevent the model from getting large shocks of gradient magnitude
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update the learning rate based on the step
        lr = get_lr(config, step)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        optimizer.step()
        if worker_state.is_cuda:
            # Make the step timing below reflect actual GPU work.
            torch.cuda.synchronize()

        stats.end_step(loss_accum, norm.item())
        if worker_state.is_primary:
            log.log(stats.log_record())
def
train( raw_model: nano_gpt.model.GPT, optimizer: torch.optim.Optimizer, worker_state: nano_gpt.trainer.WorkerState, config: nano_gpt.config.TrainConfig, train_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]], eval_config: nano_gpt.config.EvalConfig | None = None, dataset_config: nano_gpt.config.DatasetConfig | None = None, val_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]] | None = None, hellaswag_loader: Iterable[nano_gpt.datasets.hellaswag.Sample] | None = None, sample_config: nano_gpt.config.SampleConfig | None = None) -> None:
def train(
    raw_model: GPT,
    optimizer: torch.optim.Optimizer,
    worker_state: WorkerState,
    config: TrainConfig,
    train_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]],
    eval_config: EvalConfig | None = None,
    dataset_config: DatasetConfig | None = None,
    val_data_loader: Iterable[tuple[torch.Tensor, torch.Tensor]] | None = None,
    hellaswag_loader: Iterable[hellaswag.Sample] | None = None,
    sample_config: SampleConfig | None = None,
) -> None:
    """Train the model.

    This is the main training loop. It will train the model for the number of steps
    specified in the config. It will also evaluate the model on the validation set
    and save checkpoints.
    """
    config.log_info(worker_state.ddp_world_size)
    # Only the primary rank logs to file/stdout; other ranks get a no-op log.
    if worker_state.is_primary:
        log = create_log(
            pathlib.Path(config.log_file) if config.log_file else None, log_stdout=True
        )
    else:
        log = create_log(None, False)

    model: nn.Module = raw_model
    tokenizer = raw_model.enc
    model.to(worker_state.device)
    if worker_state.ddp:
        model = DDP(model, device_ids=[worker_state.ddp_local_rank])

    train_ds = iter(train_data_loader)
    stats = TrainStats(config)
    for step in range(config.step, config.max_steps):
        last_step = step == config.max_steps - 1
        stats.start_step()

        val_stats: ValStats | None = None
        eval_step = step % config.eval_steps == 0
        checkpoint_step = step % config.checkpoint_steps == 0
        if (
            (eval_step or last_step)
            and val_data_loader is not None
            and eval_config is not None
            and eval_config.validation_steps
        ):
            model.eval()
            val_ds = iter(val_data_loader)
            with torch.no_grad():
                val_loss_accum = compute_loss(
                    model,
                    worker_state,
                    "val",
                    val_ds,
                    eval_config.validation_steps,
                    backward=False,
                )
            if worker_state.ddp:
                # BUGFIX: use the rank-averaged loss. The original reduced
                # into a tensor and then discarded it, logging/checkpointing
                # only this rank's local value.
                val_loss_tensor = torch.tensor(
                    val_loss_accum, device=worker_state.device
                )
                dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
                val_loss_accum = val_loss_tensor.item()
            val_stats = ValStats(step=step, loss_accum=val_loss_accum)
            if worker_state.is_primary:
                log.log(val_stats.log_record())

        if (
            step != 0
            and (step != config.step)  # don't save the initial checkpoint
            and (checkpoint_step or last_step)
            and worker_state.is_primary
            and config.checkpoint_dir is not None
        ):
            checkpoint_path = (
                pathlib.Path(config.checkpoint_dir) / f"checkpoint_{step:06d}.bin"
            )
            checkpoint: Checkpoint = Checkpoint(
                model_state_dict=raw_model.state_dict(),
                config=raw_model.config,
                step=step,
                val_loss_accum=(
                    val_stats.loss_accum if val_stats is not None else None
                ),
                optimizer_state_dict=optimizer.state_dict(),
                train_config=config,
                dataset_config=dataset_config,
                eval_config=eval_config,
                sample_config=sample_config,
            )
            save_checkpoint(checkpoint, checkpoint_path)
        if (
            step != 0
            and (eval_step or last_step)
            and hellaswag_loader is not None
            and eval_config is not None
            and eval_config.hellaswag_samples
        ):
            model.eval()
            with torch.no_grad():
                hellaswag_result = hellaswag_eval.evaluate(
                    model,
                    tokenizer,
                    hellaswag_loader,
                    worker_state.device,
                    eval_config.hellaswag_samples,
                )
            if worker_state.ddp:
                # Sum raw counts across ranks, then rebuild the result so the
                # accuracy reflects the whole distributed evaluation.
                num_total = torch.tensor(
                    hellaswag_result.total,
                    dtype=torch.long,
                    device=worker_state.device,
                )
                num_correct_norm = torch.tensor(
                    hellaswag_result.correct,
                    dtype=torch.long,
                    device=worker_state.device,
                )
                dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
                dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
                hellaswag_result = dataclasses.replace(
                    hellaswag_result,
                    total=int(num_total.item()),
                    correct=int(num_correct_norm.item()),
                )
            if worker_state.is_primary:
                log.log(hellaswag_result.log_record())
        if (
            step > 0
            and eval_step
            and sample_config is not None
            and sample_config.num_return_sequences
        ):
            model.eval()
            with torch.no_grad():
                # Offset the seed per rank so each rank prints distinct samples.
                samples = sample(
                    model,
                    tokenizer,
                    sample_config.text,
                    num_return_sequences=sample_config.num_return_sequences,
                    max_length=sample_config.max_length,
                    device=worker_state.device,
                    seed=sample_config.seed + worker_state.ddp_rank,
                )
            for i, s in enumerate(samples):
                print(f"rank {worker_state.ddp_rank} sample {i}: {s}")

        model.train()
        optimizer.zero_grad()
        loss_accum = compute_loss(
            model,
            worker_state,
            "train",
            train_ds,
            config.grad_accum_steps(worker_state.ddp_world_size),
            backward=True,
        )
        if worker_state.ddp:
            # BUGFIX: report the rank-averaged loss (gradients are already
            # synchronized by DDP on the last micro-step). The original
            # discarded the reduced tensor and logged the local loss.
            loss_tensor = torch.tensor(loss_accum, device=worker_state.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            loss_accum = loss_tensor.item()

        # Prevent the model from getting large shocks of gradient magnitude
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update the learning rate based on the step
        lr = get_lr(config, step)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        optimizer.step()
        if worker_state.is_cuda:
            # Make the step timing below reflect actual GPU work.
            torch.cuda.synchronize()

        stats.end_step(loss_accum, norm.item())
        if worker_state.is_primary:
            log.log(stats.log_record())
Train the model.
This is the main training loop. It will train the model for the number of steps specified in the config. It will also evaluate the model on the validation set and save checkpoints.
def
create_optimizer( raw_model: nano_gpt.model.GPT, config: nano_gpt.config.TrainConfig, checkpoint: nano_gpt.checkpoint.Checkpoint | None, is_cuda: bool) -> torch.optim.Optimizer:
def create_optimizer(
    raw_model: GPT,
    config: TrainConfig,
    checkpoint: Checkpoint | None,
    is_cuda: bool,
) -> torch.optim.Optimizer:
    """Create the optimizer with the option to resume from a checkpoint."""
    # Build the optimizer with the step-0 learning rate; fused kernels on CUDA.
    opt = raw_model.configure_optimizers(
        weight_decay=0.1,
        learning_rate=get_lr(config, 0),
        use_fused=is_cuda,
    )
    # Restore optimizer state when resuming from a checkpoint that carries it.
    resume_state = None if checkpoint is None else checkpoint.optimizer_state_dict
    if resume_state is not None:
        _LOGGER.info("Loading optimizer state from checkpoint")
        opt.load_state_dict(resume_state)
    _LOGGER.debug("Optimizer: %s", opt.state_dict())
    return opt
Create the optimizer with the option to resume from a checkpoint.