nano_gpt.checkpoint

Utilities for saving and loading checkpoints of the model and training.

  1"""Utilities for saving and loading checkpoints of the model and training."""
  2
  3import dataclasses
  4from dataclasses import dataclass
  5import json
  6import logging
  7import pathlib
  8from typing import Any
  9
 10import torch
 11import safetensors.torch as st
 12
 13from .config import GPTConfig, TrainConfig, DatasetConfig, EvalConfig, SampleConfig
 14from .model import PRETRAINED_TRANSPOSED_WEIGHTS
 15
 16__all__ = [
 17    "Checkpoint",
 18    "save_checkpoint",
 19    "load_checkpoint",
 20]
 21
 22_LOGGER = logging.getLogger(__name__)
 23
 24
 25CHECKPOINT_DIR = pathlib.Path("checkpoints")
 26
 27
 28@dataclass(frozen=True, kw_only=True)
 29class Checkpoint:
 30    """Checkpoint of the model and training state."""
 31
 32    model_state_dict: dict[str, Any]
 33    """State dict of the model."""
 34
 35    config: GPTConfig
 36    """Config of the model."""
 37
 38    step: int | None = None
 39    """Number of steps the model has been trained for."""
 40
 41    val_loss_accum: float | None = None
 42    """Accumulated validation loss."""
 43
 44    optimizer_state_dict: dict[str, Any] | None = None
 45    """State dict of the optimizer."""
 46
 47    train_config: TrainConfig
 48    """Config of the training."""
 49
 50    dataset_config: DatasetConfig | None
 51    """Config of the dataset."""
 52
 53    eval_config: EvalConfig | None
 54    """Config of the evaluation."""
 55
 56    sample_config: SampleConfig | None
 57    """Config of the sampling."""
 58
 59    name: str | None = None
 60    """Name of the checkpoint."""
 61
 62    @property
 63    def model_state_dict_for_inference(self) -> dict[str, Any]:
 64        """Return the model state dict for inference."""
 65        new_state_dict = {}
 66        for k, v in self.model_state_dict.items():
 67            if k.startswith("_orig_mod."):
 68                new_state_dict[k[len("_orig_mod.") :]] = v
 69            else:
 70                new_state_dict[k] = v
 71        return new_state_dict
 72
 73
 74def save_checkpoint(
 75    checkpoint: Checkpoint,
 76    checkpoint_path: pathlib.Path,
 77) -> None:
 78    """Save the model to disk."""
 79    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
 80    checkpoint_dict = dataclasses.asdict(checkpoint)
 81    _LOGGER.info("Saving model checkpoint to %s", checkpoint_path)
 82    torch.save(checkpoint_dict, str(checkpoint_path))
 83    _LOGGER.debug("Checkpoint saved")
 84
 85
 86def load_checkpoint(
 87    checkpoint_path: pathlib.Path, device: str | None = None
 88) -> Checkpoint:
 89    """Load the model from disk."""
 90    if not checkpoint_path.exists():
 91        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
 92    checkpoint_dict = torch.load(str(checkpoint_path), map_location=device)
 93    train_config_values = checkpoint_dict["train_config"]
 94    # The training starting step gets updated to the step at which the checkpoint was saved
 95    train_config_values["step"] = checkpoint_dict["step"]
 96    return Checkpoint(
 97        name=checkpoint_path.stem,
 98        model_state_dict=checkpoint_dict["model_state_dict"],
 99        config=GPTConfig(**checkpoint_dict["config"]),
100        step=checkpoint_dict["step"],
101        val_loss_accum=checkpoint_dict["val_loss_accum"],
102        optimizer_state_dict=checkpoint_dict["optimizer_state_dict"],
103        train_config=TrainConfig(**train_config_values),
104        dataset_config=DatasetConfig(**checkpoint_dict["dataset_config"]),
105        eval_config=EvalConfig(**checkpoint_dict["eval_config"]),
106        sample_config=SampleConfig(**checkpoint_dict["sample_config"]),
107    )
108
109
def export_checkpoint(
    checkpoint: Checkpoint,
    checkpoint_path: pathlib.Path,
    export_dir: pathlib.Path,
) -> None:
    """Export the checkpoint to safetensors format.

    Writes ``model.safetensors`` plus a GPT2-style ``config.json`` into
    *export_dir*, then reloads the exported file and verifies every tensor
    round-trips exactly.

    Args:
        checkpoint: Checkpoint to export.
        checkpoint_path: Path of the original torch checkpoint file; used
            only to log a file-size comparison.
        export_dir: Destination directory (created if missing).

    Raises:
        FileExistsError: If the model or config file already exists.
        RuntimeError: If a reloaded tensor does not match the original.
    """
    # mkdir with exist_ok=True already tolerates an existing directory, so no
    # separate exists() check is needed.
    export_dir.mkdir(parents=True, exist_ok=True)
    export_path = export_dir / "model.safetensors"
    config_path = export_dir / "config.json"
    if export_path.exists():
        raise FileExistsError(f"Model export already exists: {export_path}")
    if config_path.exists():
        raise FileExistsError(f"Config file already exists: {config_path}")
    _LOGGER.info("Exporting checkpoint to %s", export_path)
    metadata = {
        "format": "pt",
        "model_name": checkpoint.name or "gpt2",
    }

    def _needs_transpose(key: str) -> bool:
        # True for weights that GPT2LMHeadModel stores transposed relative
        # to this model (see PRETRAINED_TRANSPOSED_WEIGHTS in model.py).
        return any(key.endswith(w) for w in PRETRAINED_TRANSPOSED_WEIGHTS)

    # The property rebuilds its dict on every access; compute it once and
    # reuse it for both the export and the verification pass below.
    original_state = checkpoint.model_state_dict_for_inference
    # Put the weights in the format expected by GPT2LMHeadModel. See the
    # code in model.py which transposes the weights which is the inverse of
    # this operation.
    loaded = {}
    with torch.no_grad():
        for k, v in original_state.items():
            loaded[k] = v.t().contiguous() if _needs_transpose(k) else v.contiguous()

    model_config = dataclasses.asdict(checkpoint.config)
    config = {
        "model_type": "gpt2",
        "architectures": ["GPT2LMHeadModel"],
        "n_ctx": model_config["block_size"],
        **model_config,
        "val_loss_accum": checkpoint.val_loss_accum,
        "train_config": dataclasses.asdict(checkpoint.train_config),
    }
    if checkpoint.dataset_config:
        config["dataset_config"] = dataclasses.asdict(checkpoint.dataset_config)
    if checkpoint.eval_config or checkpoint.sample_config:
        config["task_specific_params"] = {}
        if checkpoint.eval_config:
            config["task_specific_params"]["eval_config"] = dataclasses.asdict(
                checkpoint.eval_config
            )
        if checkpoint.sample_config:
            config["task_specific_params"]["sample_config"] = dataclasses.asdict(
                checkpoint.sample_config
            )

    st.save_file(loaded, str(export_path), metadata=metadata)
    config_path.write_text(json.dumps(config, indent=4))
    _LOGGER.debug("Checkpoint exported")

    # Size comparison is informational only; safetensors carries less pickle
    # overhead so some difference is expected.
    _LOGGER.info("Verifying exported checkpoint")
    pt_size = checkpoint_path.stat().st_size
    size_diff = abs(pt_size - export_path.stat().st_size)
    diff_pct = (1.0 * size_diff) / pt_size
    _LOGGER.info("Exported file size: %d bytes", export_path.stat().st_size)
    _LOGGER.info("Original file size: %d bytes", pt_size)
    _LOGGER.info("Difference: %d bytes (%.2f%%)", size_diff, diff_pct * 100)

    reloaded = st.load_file(str(export_path))
    _LOGGER.info("Verifying tensors (%d tensors)", len(original_state))
    for k, pt_tensor in original_state.items():
        _LOGGER.debug("Verifying tensor %s", k)
        sf_tensor = reloaded[k]
        # Transposed weights were stored as v.t(), so compare against the
        # transpose of the original; everything else compares directly.
        expected = pt_tensor.t() if _needs_transpose(k) else pt_tensor
        if not torch.equal(expected, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
    _LOGGER.info("All tensors match")
@dataclass(frozen=True, kw_only=True)
class Checkpoint:
29@dataclass(frozen=True, kw_only=True)
30class Checkpoint:
31    """Checkpoint of the model and training state."""
32
33    model_state_dict: dict[str, Any]
34    """State dict of the model."""
35
36    config: GPTConfig
37    """Config of the model."""
38
39    step: int | None = None
40    """Number of steps the model has been trained for."""
41
42    val_loss_accum: float | None = None
43    """Accumulated validation loss."""
44
45    optimizer_state_dict: dict[str, Any] | None = None
46    """State dict of the optimizer."""
47
48    train_config: TrainConfig
49    """Config of the training."""
50
51    dataset_config: DatasetConfig | None
52    """Config of the dataset."""
53
54    eval_config: EvalConfig | None
55    """Config of the evaluation."""
56
57    sample_config: SampleConfig | None
58    """Config of the sampling."""
59
60    name: str | None = None
61    """Name of the checkpoint."""
62
63    @property
64    def model_state_dict_for_inference(self) -> dict[str, Any]:
65        """Return the model state dict for inference."""
66        new_state_dict = {}
67        for k, v in self.model_state_dict.items():
68            if k.startswith("_orig_mod."):
69                new_state_dict[k[len("_orig_mod.") :]] = v
70            else:
71                new_state_dict[k] = v
72        return new_state_dict

Checkpoint of the model and training state.

Checkpoint( *, model_state_dict: dict[str, typing.Any], config: nano_gpt.config.GPTConfig, step: int | None = None, val_loss_accum: float | None = None, optimizer_state_dict: dict[str, Any] | None = None, train_config: nano_gpt.config.TrainConfig, dataset_config: nano_gpt.config.DatasetConfig | None, eval_config: nano_gpt.config.EvalConfig | None, sample_config: nano_gpt.config.SampleConfig | None, name: str | None = None)
model_state_dict: dict[str, typing.Any]

State dict of the model.

Config of the model.

step: int | None = None

Number of steps the model has been trained for.

val_loss_accum: float | None = None

Accumulated validation loss.

optimizer_state_dict: dict[str, Any] | None = None

State dict of the optimizer.

Config of the training.

dataset_config: nano_gpt.config.DatasetConfig | None

Config of the dataset.

eval_config: nano_gpt.config.EvalConfig | None

Config of the evaluation.

sample_config: nano_gpt.config.SampleConfig | None

Config of the sampling.

name: str | None = None

Name of the checkpoint.

model_state_dict_for_inference: dict[str, typing.Any]
63    @property
64    def model_state_dict_for_inference(self) -> dict[str, Any]:
65        """Return the model state dict for inference."""
66        new_state_dict = {}
67        for k, v in self.model_state_dict.items():
68            if k.startswith("_orig_mod."):
69                new_state_dict[k[len("_orig_mod.") :]] = v
70            else:
71                new_state_dict[k] = v
72        return new_state_dict

Return the model state dict for inference.

def save_checkpoint( checkpoint: Checkpoint, checkpoint_path: pathlib.Path) -> None:
75def save_checkpoint(
76    checkpoint: Checkpoint,
77    checkpoint_path: pathlib.Path,
78) -> None:
79    """Save the model to disk."""
80    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
81    checkpoint_dict = dataclasses.asdict(checkpoint)
82    _LOGGER.info("Saving model checkpoint to %s", checkpoint_path)
83    torch.save(checkpoint_dict, str(checkpoint_path))
84    _LOGGER.debug("Checkpoint saved")

Save the model to disk.

def load_checkpoint( checkpoint_path: pathlib.Path, device: str | None = None) -> Checkpoint:
 87def load_checkpoint(
 88    checkpoint_path: pathlib.Path, device: str | None = None
 89) -> Checkpoint:
 90    """Load the model from disk."""
 91    if not checkpoint_path.exists():
 92        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
 93    checkpoint_dict = torch.load(str(checkpoint_path), map_location=device)
 94    train_config_values = checkpoint_dict["train_config"]
 95    # The training starting step gets updated to the step at which the checkpoint was saved
 96    train_config_values["step"] = checkpoint_dict["step"]
 97    return Checkpoint(
 98        name=checkpoint_path.stem,
 99        model_state_dict=checkpoint_dict["model_state_dict"],
100        config=GPTConfig(**checkpoint_dict["config"]),
101        step=checkpoint_dict["step"],
102        val_loss_accum=checkpoint_dict["val_loss_accum"],
103        optimizer_state_dict=checkpoint_dict["optimizer_state_dict"],
104        train_config=TrainConfig(**train_config_values),
105        dataset_config=DatasetConfig(**checkpoint_dict["dataset_config"]),
106        eval_config=EvalConfig(**checkpoint_dict["eval_config"]),
107        sample_config=SampleConfig(**checkpoint_dict["sample_config"]),
108    )

Load the model from disk.