nano_gpt.checkpoint
Utilities for saving and loading checkpoints of the model and training.
1"""Utilities for saving and loading checkpoints of the model and training.""" 2 3import dataclasses 4from dataclasses import dataclass 5import json 6import logging 7import pathlib 8from typing import Any 9 10import torch 11import safetensors.torch as st 12 13from .config import GPTConfig, TrainConfig, DatasetConfig, EvalConfig, SampleConfig 14from .model import PRETRAINED_TRANSPOSED_WEIGHTS 15 16__all__ = [ 17 "Checkpoint", 18 "save_checkpoint", 19 "load_checkpoint", 20] 21 22_LOGGER = logging.getLogger(__name__) 23 24 25CHECKPOINT_DIR = pathlib.Path("checkpoints") 26 27 28@dataclass(frozen=True, kw_only=True) 29class Checkpoint: 30 """Checkpoint of the model and training state.""" 31 32 model_state_dict: dict[str, Any] 33 """State dict of the model.""" 34 35 config: GPTConfig 36 """Config of the model.""" 37 38 step: int | None = None 39 """Number of steps the model has been trained for.""" 40 41 val_loss_accum: float | None = None 42 """Accumulated validation loss.""" 43 44 optimizer_state_dict: dict[str, Any] | None = None 45 """State dict of the optimizer.""" 46 47 train_config: TrainConfig 48 """Config of the training.""" 49 50 dataset_config: DatasetConfig | None 51 """Config of the dataset.""" 52 53 eval_config: EvalConfig | None 54 """Config of the evaluation.""" 55 56 sample_config: SampleConfig | None 57 """Config of the sampling.""" 58 59 name: str | None = None 60 """Name of the checkpoint.""" 61 62 @property 63 def model_state_dict_for_inference(self) -> dict[str, Any]: 64 """Return the model state dict for inference.""" 65 new_state_dict = {} 66 for k, v in self.model_state_dict.items(): 67 if k.startswith("_orig_mod."): 68 new_state_dict[k[len("_orig_mod.") :]] = v 69 else: 70 new_state_dict[k] = v 71 return new_state_dict 72 73 74def save_checkpoint( 75 checkpoint: Checkpoint, 76 checkpoint_path: pathlib.Path, 77) -> None: 78 """Save the model to disk.""" 79 checkpoint_path.parent.mkdir(parents=True, exist_ok=True) 80 checkpoint_dict = 
dataclasses.asdict(checkpoint) 81 _LOGGER.info("Saving model checkpoint to %s", checkpoint_path) 82 torch.save(checkpoint_dict, str(checkpoint_path)) 83 _LOGGER.debug("Checkpoint saved") 84 85 86def load_checkpoint( 87 checkpoint_path: pathlib.Path, device: str | None = None 88) -> Checkpoint: 89 """Load the model from disk.""" 90 if not checkpoint_path.exists(): 91 raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}") 92 checkpoint_dict = torch.load(str(checkpoint_path), map_location=device) 93 train_config_values = checkpoint_dict["train_config"] 94 # The training starting step gets updated to the step at which the checkpoint was saved 95 train_config_values["step"] = checkpoint_dict["step"] 96 return Checkpoint( 97 name=checkpoint_path.stem, 98 model_state_dict=checkpoint_dict["model_state_dict"], 99 config=GPTConfig(**checkpoint_dict["config"]), 100 step=checkpoint_dict["step"], 101 val_loss_accum=checkpoint_dict["val_loss_accum"], 102 optimizer_state_dict=checkpoint_dict["optimizer_state_dict"], 103 train_config=TrainConfig(**train_config_values), 104 dataset_config=DatasetConfig(**checkpoint_dict["dataset_config"]), 105 eval_config=EvalConfig(**checkpoint_dict["eval_config"]), 106 sample_config=SampleConfig(**checkpoint_dict["sample_config"]), 107 ) 108 109 110def export_checkpoint( 111 checkpoint: Checkpoint, 112 checkpoint_path: pathlib.Path, 113 export_dir: pathlib.Path, 114) -> None: 115 """Export the checkpoint to safetensors format.""" 116 if not export_dir.exists(): 117 export_dir.mkdir(parents=True, exist_ok=True) 118 export_path = export_dir / "model.safetensors" 119 config_path = export_dir / "config.json" 120 if export_path.exists(): 121 raise FileExistsError(f"Model export already exists: {export_path}") 122 if config_path.exists(): 123 raise FileExistsError(f"Config file already exists: {config_path}") 124 _LOGGER.info("Exporting checkpoint to %s", export_path) 125 metadata = { 126 "format": "pt", 127 "model_name": 
checkpoint.name or "gpt2", 128 } 129 original_state = checkpoint.model_state_dict_for_inference 130 # Put the weights in the format expected by GPT2LMHeadModel. See the 131 # code in model.py which transposes the weights which is the inverse of 132 # this operation. 133 loaded = {} 134 for k, v in original_state.items(): 135 if any(k.endswith(w) for w in PRETRAINED_TRANSPOSED_WEIGHTS): 136 with torch.no_grad(): 137 loaded[k] = v.t().contiguous() 138 else: 139 loaded[k] = v.contiguous() 140 141 model_config = dataclasses.asdict(checkpoint.config) 142 config = { 143 "model_type": "gpt2", 144 "architectures": ["GPT2LMHeadModel"], 145 "n_ctx": model_config["block_size"], 146 **model_config, 147 "val_loss_accum": checkpoint.val_loss_accum, 148 "train_config": dataclasses.asdict(checkpoint.train_config), 149 } 150 if checkpoint.dataset_config: 151 config["dataset_config"] = dataclasses.asdict(checkpoint.dataset_config) 152 if checkpoint.eval_config or checkpoint.sample_config: 153 config["task_specific_params"] = {} 154 if checkpoint.eval_config: 155 config["task_specific_params"]["eval_config"] = dataclasses.asdict( 156 checkpoint.eval_config 157 ) 158 if checkpoint.sample_config: 159 config["task_specific_params"]["sample_config"] = dataclasses.asdict( 160 checkpoint.sample_config 161 ) 162 163 st.save_file(loaded, str(export_path), metadata=metadata) 164 config_path.write_text(json.dumps(config, indent=4)) 165 _LOGGER.debug("Checkpoint exported") 166 167 _LOGGER.info("Verifying exported checkpoint") 168 pt_size = checkpoint_path.stat().st_size 169 size_diff = abs(pt_size - export_path.stat().st_size) 170 diff_pct = (1.0 * size_diff) / pt_size 171 _LOGGER.info("Exported file size: %d bytes", export_path.stat().st_size) 172 _LOGGER.info("Original file size: %d bytes", pt_size) 173 _LOGGER.info("Difference: %d bytes (%.2f%%)", size_diff, diff_pct * 100) 174 175 reloaded = st.load_file(str(export_path)) 176 _LOGGER.info( 177 "Verifying tensors (%d tensors)", 
len(checkpoint.model_state_dict_for_inference) 178 ) 179 for k, pt_tensor in checkpoint.model_state_dict_for_inference.items(): 180 _LOGGER.debug("Verifying tensor %s", k) 181 sf_tensor = reloaded[k] 182 if any(k.endswith(w) for w in PRETRAINED_TRANSPOSED_WEIGHTS): 183 if not torch.equal(pt_tensor.t(), sf_tensor): 184 raise RuntimeError(f"The output tensors do not match for key {k}") 185 else: 186 if not torch.equal(pt_tensor, sf_tensor): 187 raise RuntimeError(f"The output tensors do not match for key {k}") 188 _LOGGER.info("All tensors match")
@dataclass(frozen=True, kw_only=True)
class Checkpoint:
    """Checkpoint of the model and training state."""

    model_state_dict: dict[str, Any]
    """State dict of the model."""

    config: GPTConfig
    """Config of the model."""

    step: int | None = None
    """Number of steps the model has been trained for."""

    val_loss_accum: float | None = None
    """Accumulated validation loss."""

    optimizer_state_dict: dict[str, Any] | None = None
    """State dict of the optimizer."""

    train_config: TrainConfig
    """Config of the training."""

    dataset_config: DatasetConfig | None
    """Config of the dataset."""

    eval_config: EvalConfig | None
    """Config of the evaluation."""

    sample_config: SampleConfig | None
    """Config of the sampling."""

    name: str | None = None
    """Name of the checkpoint."""

    @property
    def model_state_dict_for_inference(self) -> dict[str, Any]:
        """Return the model state dict for inference."""
        # torch.compile prefixes parameter names with "_orig_mod."; strip it
        # so the weights load into an uncompiled model.
        return {
            key.removeprefix("_orig_mod."): value
            for key, value in self.model_state_dict.items()
        }
Checkpoint of the model and training state.
Checkpoint( *, model_state_dict: dict[str, typing.Any], config: nano_gpt.config.GPTConfig, step: int | None = None, val_loss_accum: float | None = None, optimizer_state_dict: dict[str, typing.Any] | None = None, train_config: nano_gpt.config.TrainConfig, dataset_config: nano_gpt.config.DatasetConfig | None, eval_config: nano_gpt.config.EvalConfig | None, sample_config: nano_gpt.config.SampleConfig | None, name: str | None = None)
model_state_dict_for_inference: dict[str, typing.Any]
63 @property 64 def model_state_dict_for_inference(self) -> dict[str, Any]: 65 """Return the model state dict for inference.""" 66 new_state_dict = {} 67 for k, v in self.model_state_dict.items(): 68 if k.startswith("_orig_mod."): 69 new_state_dict[k[len("_orig_mod.") :]] = v 70 else: 71 new_state_dict[k] = v 72 return new_state_dict
Return the model state dict for inference.
def save_checkpoint(
    checkpoint: Checkpoint,
    checkpoint_path: pathlib.Path,
) -> None:
    """Save the model to disk."""
    # Ensure the destination directory exists before writing.
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    _LOGGER.info("Saving model checkpoint to %s", checkpoint_path)
    # Serialize the dataclass as a plain dict so loading does not need to
    # unpickle this module's classes.
    serializable = dataclasses.asdict(checkpoint)
    torch.save(serializable, str(checkpoint_path))
    _LOGGER.debug("Checkpoint saved")
Save the model to disk.
def load_checkpoint(
    checkpoint_path: pathlib.Path, device: str | None = None
) -> Checkpoint:
    """Load the model from disk.

    Raises FileNotFoundError if the checkpoint file does not exist.
    """
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint_dict = torch.load(str(checkpoint_path), map_location=device)
    train_config_values = checkpoint_dict["train_config"]
    # The training starting step gets updated to the step at which the checkpoint was saved
    train_config_values["step"] = checkpoint_dict["step"]
    # The optional sub-configs are declared `... | None` on Checkpoint, so a
    # saved checkpoint may contain None for any of them; guard the
    # reconstruction instead of splatting a None (which raises TypeError).
    dataset_config_values = checkpoint_dict["dataset_config"]
    eval_config_values = checkpoint_dict["eval_config"]
    sample_config_values = checkpoint_dict["sample_config"]
    return Checkpoint(
        name=checkpoint_path.stem,
        model_state_dict=checkpoint_dict["model_state_dict"],
        config=GPTConfig(**checkpoint_dict["config"]),
        step=checkpoint_dict["step"],
        val_loss_accum=checkpoint_dict["val_loss_accum"],
        optimizer_state_dict=checkpoint_dict["optimizer_state_dict"],
        train_config=TrainConfig(**train_config_values),
        dataset_config=(
            DatasetConfig(**dataset_config_values)
            if dataset_config_values is not None
            else None
        ),
        eval_config=(
            EvalConfig(**eval_config_values)
            if eval_config_values is not None
            else None
        ),
        sample_config=(
            SampleConfig(**sample_config_values)
            if sample_config_values is not None
            else None
        ),
    )
Load the model from disk.