nano_gpt.model

This module defines the GPT model architecture.

This is a thin wrapper around the HuggingFace transformers library and uses the approach from the GPT-2/GPT-3 papers.

  1"""This module defines the GPT model architecture.
  2
  3This is a thin wrapper around the HuggingFace transformers library and uses
  4the approach from the GPT-2/GPT-3 papers.
  5"""
  6
  7import logging
  8from typing import cast, Any
  9
 10import torch
 11import torch.nn as nn
 12from torch.nn import functional as F
 13from transformers import GPT2LMHeadModel
 14
 15from .tokenizer import Tokenizer
 16from .config import GPTConfig
 17
 18_LOGGER = logging.getLogger(__name__)
 19
 20__all__ = [
 21    "GPT",
 22    "sample",
 23]
 24
 25# The openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
 26# this means that we have to transpose these weights when we import them
 27PRETRAINED_TRANSPOSED_WEIGHTS = [
 28    "attn.c_attn.weight",
 29    "attn.c_proj.weight",
 30    "mlp.c_fc.weight",
 31    "mlp.c_proj.weight",
 32]
 33
 34
 35class CausalSelfAttention(nn.Module):
 36    """Attention module."""
 37
 38    def __init__(self, config: GPTConfig) -> None:
 39        """Initialize MLP."""
 40        super().__init__()
 41        # Batch of key/query/value projects for all heads
 42        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
 43        # Output projection
 44        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
 45        self.c_proj.SCALE_INIT = 1  # type: ignore[assignment]
 46        # Regularization
 47        self.n_head = config.n_head
 48        self.n_embed = config.n_embd
 49        self.register_buffer(
 50            "bias",
 51            torch.tril(torch.ones(config.block_size, config.block_size)).view(
 52                1, 1, config.block_size, config.block_size
 53            ),
 54        )
 55
 56    def forward(self, x: torch.Tensor) -> torch.Tensor:
 57        """Perform inference."""
 58        B, T, C = x.size()
 59        # Compute the query, key, value for all heads in the batch.
 60        qkv = self.c_attn(x)
 61        q, k, v = qkv.split(self.n_embed, dim=2)
 62        # Each are (B, nh, T, hs)
 63        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
 64        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
 65        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
 66
 67        # Compute attention with a fused kernel of fast attention
 68        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
 69
 70        # Reassemble and concat everything
 71        y = y.transpose(1, 2).contiguous().view(B, T, C)
 72
 73        # Output projection
 74        y = self.c_proj(y)
 75        return y  # type: ignore[no-any-return]
 76
 77
 78class MLP(nn.Module):
 79    """Multi-layer perceptron."""
 80
 81    def __init__(self, config: GPTConfig) -> None:
 82        """Initialize MLP."""
 83        super().__init__()
 84        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
 85        self.gelu = nn.GELU(approximate="tanh")
 86        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
 87        self.c_proj.SCALE_INIT = 1  # type: ignore[assignment]
 88
 89    def forward(self, x: torch.Tensor) -> torch.Tensor:
 90        """Perform inference."""
 91        x = self.c_fc(x)
 92        x = self.gelu(x)
 93        x = self.c_proj(x)
 94        return x
 95
 96
 97class Block(nn.Module):
 98    """A transformer block."""
 99
100    def __init__(self, config: GPTConfig) -> None:
101        """Initialize Block."""
102        super().__init__()
103        self.config = config
104        self.ln_1 = nn.LayerNorm(config.n_embd)
105        self.attn = CausalSelfAttention(config)
106        self.ln_2 = nn.LayerNorm(config.n_embd)
107        self.mlp = MLP(config)
108
109    def forward(self, x: torch.Tensor) -> torch.Tensor:
110        """Perform inference."""
111        x = x + self.attn(self.ln_1(x))
112        x = x + self.mlp(self.ln_2(x))
113        return x
114
115
class GPT(nn.Module):
    """This class defines the GPT model."""

    def __init__(self, config: GPTConfig, tokenizer: Tokenizer) -> None:
        """Initialize GPT.

        Args:
            config: Model architecture hyperparameters.
            tokenizer: Tokenizer used to encode/decode text for this model.
        """
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(config.vocab_size, config.n_embd),
                "wpe": nn.Embedding(config.block_size, config.n_embd),
                "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                "ln_f": nn.LayerNorm(config.n_embd),
            }
        )
        # Final classifier head projecting hidden states to vocabulary logits.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.enc = tokenizer

        # Share weights for input and output embeddings. This is about 30% of
        # the model weights.
        self.transformer.wte.weight = self.lm_head.weight  # type: ignore[union-attr]
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Perform additional weight initialization to match gpt-2."""
        std = 0.02
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                # Residual-stream projections are scaled by 1/sqrt(2 * n_layer)
                # because each block adds two residual contributions (GPT-2).
                # Fix: the exponent was -0.05; the GPT-2 scheme uses -0.5.
                std *= (2 * self.config.n_layer) ** -0.5
            # Fix: initialize *every* Linear weight with N(0, std). Previously
            # only the SCALE_INIT projections were re-initialized, leaving the
            # other Linear layers at the PyTorch default instead of the
            # gpt-2 N(0, 0.02) init the docstring promises.
            torch.nn.init.normal_(module.weight, mean=0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=std)

    def configure_optimizers(
        self,
        weight_decay: float,
        learning_rate: float,
        use_fused: bool,
    ) -> torch.optim.AdamW:
        """Return the optimizer.

        Args:
            weight_decay: Weight decay applied to the >=2D parameters (matmul
                weights and embeddings); 1-D params (biases, layernorm scales)
                are excluded.
            learning_rate: Learning rate passed to AdamW.
            use_fused: Whether to use the fused AdamW kernel.
        """
        # Start with all params that require grad.
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        # Create the optimizer groups: decay only multi-dimensional tensors.
        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
        nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        _LOGGER.info(
            "num decay_params %s (tensors) / %s (parameters)",
            len(decay_params),
            num_decay_params,
        )
        _LOGGER.info(
            "num nodecay_params %s (tensors) / %s (parameters)",
            len(nodecay_params),
            num_nodecay_params,
        )
        _LOGGER.info("Using fused adamw : %s", use_fused)
        # Fix: pass the learning_rate argument through; it was previously
        # ignored in favor of a hard-coded 3e-4.
        return torch.optim.AdamW(
            params=optim_groups,
            lr=learning_rate,
            betas=(0.9, 0.95),
            eps=1e-8,
            fused=use_fused,
        )

    def forward(
        self, x: torch.Tensor, targets: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Perform the forward pass.

        Returns the output values and loss if targets are provided.

        Args:
            x: Token ids of shape (B, T).
            targets: Optional target token ids of shape (B, T) used to
                compute the cross-entropy loss.

        Raises:
            ValueError: If the sequence length exceeds the block size.
        """
        _, T = x.size()
        if T > self.config.block_size:
            raise ValueError(
                f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
            )
        # Forward token and positional embeddings.
        pos = torch.arange(0, T, dtype=torch.long, device=x.device)  # Shape (T)

        # (T, n_emb)
        pos_emb = cast(nn.Embedding, self.transformer.wpe)(pos)
        # (B, T, n_emb); positional embeddings broadcast over the batch.
        tok_emb = cast(nn.Embedding, self.transformer.wte)(x)
        x = tok_emb + pos_emb
        # Forward transformer blocks.
        for block in cast(nn.ModuleList, self.transformer.h):
            x = block(x)
        # Forward the final layernorm.
        x = cast(nn.LayerNorm, self.transformer.ln_f)(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                # Flatten to (BxT, vocab_size)
                logits.view(-1, logits.size(-1)),
                # Flatten to (BxT)
                targets.view(-1),
            )
        return logits, loss

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        tokenizer: Tokenizer,
        model_config: GPTConfig,
        **kwargs: Any,
    ) -> "GPT":
        """Load the GPT from the pretrained model.

        Copies the weights from a HuggingFace GPT2LMHeadModel checkpoint into
        a freshly constructed GPT, transposing the Conv1D weights as needed.

        Raises:
            ValueError: If the checkpoint keys or shapes do not line up.
        """
        model = cls(model_config, tokenizer=tokenizer)
        sd = model.state_dict()
        # Discard the causal-mask buffer, it is not a learned parameter.
        sd_keys = [k for k in sd.keys() if not k.endswith(".attn.bias")]

        # Init a huggingface/transformers model.
        model_hf = GPT2LMHeadModel.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        sd_hf = model_hf.state_dict()

        # Copy while ensuring all of the parameters are aligned and match in
        # names and shapes; masked_bias and bias are buffers, not params.
        sd_keys_hf = [k for k in sd_hf.keys() if not k.endswith(".attn.masked_bias")]
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.bias")]
        # Fix: raise instead of assert so the check survives `python -O`.
        if len(sd_keys_hf) != len(sd_keys):
            raise ValueError(
                f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
            )
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in PRETRAINED_TRANSPOSED_WEIGHTS):
                # Special treatment for the Conv1D weights we need to transpose.
                # Fix: compare the full transposed shape; the old check only
                # compared one dimension and could miss a mismatch.
                if tuple(sd_hf[k].shape[::-1]) != tuple(sd[k].shape):
                    raise ValueError(
                        f"mismatched shapes: {sd_hf[k].shape[::-1]} != {sd[k].shape} for key {k}"
                    )
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # Vanilla copy over the other parameters.
                if sd_hf[k].shape != sd[k].shape:
                    raise ValueError(
                        f"mismatched shapes: {sd_hf[k].shape} != {sd[k].shape} for key {k}"
                    )
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model
271
272
def sample(
    model: nn.Module,
    tokenizer: Tokenizer,
    text: str,
    num_return_sequences: int,
    max_length: int,
    device: str,
    seed: int = 42,
    top_k: int = 50,
) -> list[str]:
    """Sample from the model from text input.

    Args:
        model: Model whose forward returns (logits, loss).
        tokenizer: Tokenizer used to encode the prompt and decode samples.
        text: Prompt text shared by all returned sequences.
        num_return_sequences: Number of independent sequences to generate.
        max_length: Total length (prompt + generated) of each sequence.
        device: Device to run generation on.
        seed: Seed for the sampling RNG so results are reproducible.
        top_k: Number of highest-probability tokens to sample from; 50 is
            the huggingface default and keeps the model from going off the
            rails as easily.
    """
    prompt_tokens = tokenizer.encode(text)
    tokens = torch.tensor(prompt_tokens, dtype=torch.long)  # (T,)
    # Replicate the prompt once per requested sequence: (B, T).
    x = tokens.unsqueeze(0).repeat(num_return_sequences, 1).to(device)

    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(seed)

    # Each loop iteration appends one token (one more column of x).
    # no_grad: generation never backpropagates, so skip storing activations.
    with torch.no_grad():
        while x.size(1) < max_length:
            logits, _ = model(x)  # (B, T, vocab_size)
            # Keep only the logits at the last position (the next token);
            # recomputing the whole sequence each step is correct but
            # inefficient.
            logits = logits[:, -1, :]  # (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            # Top-k sampling: keep the top_k probabilities and implicitly set
            # every other token's probability to zero. Both are (B, top_k).
            topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
            # Select one token per sequence from the top_k candidates.
            ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # (B, 1)
            # Map the sampled positions back to vocabulary token ids.
            xcol = torch.gather(topk_indices, -1, ix)
            # Append the new token to every sequence in the batch.
            x = torch.cat((x, xcol), dim=-1)

    return [
        tokenizer.decode(x[i, :max_length].tolist())
        for i in range(num_return_sequences)
    ]
class GPT(torch.nn.modules.module.Module):
117class GPT(nn.Module):
118    """This class defines the GPT model."""
119
120    def __init__(self, config: GPTConfig, tokenizer: Tokenizer) -> None:
121        super().__init__()
122        self.config = config
123
124        self.transformer = nn.ModuleDict(
125            {
126                "wte": nn.Embedding(config.vocab_size, config.n_embd),
127                "wpe": nn.Embedding(config.block_size, config.n_embd),
128                "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
129                "ln_f": nn.LayerNorm(config.n_embd),
130            }
131        )
132        # Final classifier
133        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
134        self.enc = tokenizer
135
136        # Share weights for input and output embeddings. This is about 30% of
137        # the model weights.
138        self.transformer.wte.weight = self.lm_head.weight  # type: ignore[union-attr]
139        self.apply(self._init_weights)
140
141    def _init_weights(self, module: nn.Module) -> None:
142        """Perform additional weight initialization to match gpt-2."""
143        std = 0.02
144        if isinstance(module, nn.Linear):
145            if hasattr(module, "SCALE_INIT"):
146                std *= (2 * self.config.n_layer) ** -0.05
147                torch.nn.init.normal_(module.weight, mean=0, std=std)
148            if module.bias is not None:
149                torch.nn.init.zeros_(module.bias)
150        elif isinstance(module, nn.Embedding):
151            torch.nn.init.normal_(module.weight, mean=0, std=std)
152
153    def configure_optimizers(
154        self,
155        weight_decay: float,
156        learning_rate: float,
157        use_fused: bool,
158    ) -> torch.optim.AdamW:
159        """Return the optimizer."""
160        # start with all params that require grad
161        param_dict = {pn: p for pn, p in self.named_parameters()}
162        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
163        # create optim optim_groups
164        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
165        nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
166        optim_groups = [
167            {"params": decay_params, "weight_decay": weight_decay},
168            {"params": nodecay_params, "weight_decay": 0},
169        ]
170        num_decay_params = sum(p.numel() for p in decay_params)
171        num_nodecay_params = sum(p.numel() for p in nodecay_params)
172        _LOGGER.info(
173            "num decay_params %s (tensors) / %s (parameters)",
174            len(decay_params),
175            num_decay_params,
176        )
177        _LOGGER.info(
178            "num nodecay_params %s (tensors) / %s (parameters)",
179            len(nodecay_params),
180            num_nodecay_params,
181        )
182        _LOGGER.info("Using fused adamw : %s", use_fused)
183        return torch.optim.AdamW(
184            params=optim_groups, lr=3e-4, betas=(0.9, 0.95), eps=1e-8, fused=use_fused
185        )
186
187    def forward(
188        self, x: torch.Tensor, targets: torch.Tensor | None = None
189    ) -> tuple[torch.Tensor, torch.Tensor | None]:
190        """Perform the forward pass.
191
192        Returns the output values and loss if targets are provided.
193        """
194        B, T = x.size()
195        if T > self.config.block_size:
196            raise ValueError(
197                f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
198            )
199        # Forward token and positional embeddings
200        pos = torch.arange(0, T, dtype=torch.long, device=x.device)  # Shape (T)
201
202        # (T, n_emb)
203        pos_emb = cast(nn.Embedding, self.transformer.wpe)(pos)
204        # (B, T, n_emb)
205        tok_emb = cast(nn.Embedding, self.transformer.wte)(x)
206        x = tok_emb + pos_emb
207        # Forward transformer blocks
208        for block in cast(nn.ModuleList, self.transformer.h):
209            x = block(x)
210        # Forward the final layernorm
211        x = cast(nn.LayerNorm, self.transformer.ln_f)(x)
212        logits = self.lm_head(x)  # (B, T, vocab_size)
213        loss = None
214        if targets is not None:
215            loss = F.cross_entropy(
216                # Flatten to (BxT, vocab_size)
217                logits.view(-1, logits.size(-1)),
218                # Flatten to (BxT)
219                targets.view(-1),
220            )
221        return logits, loss
222
223    @classmethod
224    def from_pretrained(
225        cls,
226        pretrained_model_name_or_path: str,
227        tokenizer: Tokenizer,
228        model_config: GPTConfig,
229        **kwargs: Any,
230    ) -> "GPT":
231        """Load the GPT from the pretrained model."""
232        model = GPT(model_config, tokenizer=tokenizer)
233        sd = model.state_dict()
234        sd_keys = [
235            k for k in sd.keys() if not k.endswith(".attn.bias")
236        ]  # discard this mask / buffer, not a param
237
238        # init a huggingface/transformers model
239        model_hf = GPT2LMHeadModel.from_pretrained(
240            pretrained_model_name_or_path, **kwargs
241        )
242        sd_hf = model_hf.state_dict()
243
244        # copy while ensuring all of the parameters are aligned and match in names and shapes
245        sd_keys_hf = [
246            k for k in sd_hf.keys() if not k.endswith(".attn.masked_bias")
247        ]  # ignore these, just a buffer
248        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.bias")]
249        # Transpose weights
250        assert len(sd_keys_hf) == len(sd_keys), (
251            f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
252        )
253        for k in sd_keys_hf:
254            if any(k.endswith(w) for w in PRETRAINED_TRANSPOSED_WEIGHTS):
255                # special treatment for the Conv1D weights we need to transpose
256                if sd_hf[k].shape[0] != sd[k].shape[1]:
257                    raise ValueError(
258                        f"mismatched shapes: {sd_hf[k].shape[::-1]} != {sd[k].shape} for key {k}"
259                    )
260                with torch.no_grad():
261                    sd[k].copy_(sd_hf[k].t())
262            else:
263                # vanilla copy over the other parameters
264                if sd_hf[k].shape != sd[k].shape:
265                    raise ValueError(
266                        f"mismatched shapes: {sd_hf[k].shape} != {sd[k].shape} for key {k}"
267                    )
268                with torch.no_grad():
269                    sd[k].copy_(sd_hf[k])
270
271        return model

This class defines the GPT model.

GPT( config: nano_gpt.config.GPTConfig, tokenizer: nano_gpt.tokenizer.Tokenizer)
120    def __init__(self, config: GPTConfig, tokenizer: Tokenizer) -> None:
121        super().__init__()
122        self.config = config
123
124        self.transformer = nn.ModuleDict(
125            {
126                "wte": nn.Embedding(config.vocab_size, config.n_embd),
127                "wpe": nn.Embedding(config.block_size, config.n_embd),
128                "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
129                "ln_f": nn.LayerNorm(config.n_embd),
130            }
131        )
132        # Final classifier
133        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
134        self.enc = tokenizer
135
136        # Share weights for input and output embeddings. This is about 30% of
137        # the model weights.
138        self.transformer.wte.weight = self.lm_head.weight  # type: ignore[union-attr]
139        self.apply(self._init_weights)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

config
transformer
lm_head
enc
def configure_optimizers( self, weight_decay: float, learning_rate: float, use_fused: bool) -> torch.optim.adamw.AdamW:
153    def configure_optimizers(
154        self,
155        weight_decay: float,
156        learning_rate: float,
157        use_fused: bool,
158    ) -> torch.optim.AdamW:
159        """Return the optimizer."""
160        # start with all params that require grad
161        param_dict = {pn: p for pn, p in self.named_parameters()}
162        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
163        # create optim optim_groups
164        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
165        nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
166        optim_groups = [
167            {"params": decay_params, "weight_decay": weight_decay},
168            {"params": nodecay_params, "weight_decay": 0},
169        ]
170        num_decay_params = sum(p.numel() for p in decay_params)
171        num_nodecay_params = sum(p.numel() for p in nodecay_params)
172        _LOGGER.info(
173            "num decay_params %s (tensors) / %s (parameters)",
174            len(decay_params),
175            num_decay_params,
176        )
177        _LOGGER.info(
178            "num nodecay_params %s (tensors) / %s (parameters)",
179            len(nodecay_params),
180            num_nodecay_params,
181        )
182        _LOGGER.info("Using fused adamw : %s", use_fused)
183        return torch.optim.AdamW(
184            params=optim_groups, lr=3e-4, betas=(0.9, 0.95), eps=1e-8, fused=use_fused
185        )

Return the optimizer.

def forward( self, x: torch.Tensor, targets: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor | None]:
187    def forward(
188        self, x: torch.Tensor, targets: torch.Tensor | None = None
189    ) -> tuple[torch.Tensor, torch.Tensor | None]:
190        """Perform the forward pass.
191
192        Returns the output values and loss if targets are provided.
193        """
194        B, T = x.size()
195        if T > self.config.block_size:
196            raise ValueError(
197                f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
198            )
199        # Forward token and positional embeddings
200        pos = torch.arange(0, T, dtype=torch.long, device=x.device)  # Shape (T)
201
202        # (T, n_emb)
203        pos_emb = cast(nn.Embedding, self.transformer.wpe)(pos)
204        # (B, T, n_emb)
205        tok_emb = cast(nn.Embedding, self.transformer.wte)(x)
206        x = tok_emb + pos_emb
207        # Forward transformer blocks
208        for block in cast(nn.ModuleList, self.transformer.h):
209            x = block(x)
210        # Forward the final layernorm
211        x = cast(nn.LayerNorm, self.transformer.ln_f)(x)
212        logits = self.lm_head(x)  # (B, T, vocab_size)
213        loss = None
214        if targets is not None:
215            loss = F.cross_entropy(
216                # Flatten to (BxT, vocab_size)
217                logits.view(-1, logits.size(-1)),
218                # Flatten to (BxT)
219                targets.view(-1),
220            )
221        return logits, loss

Perform the forward pass.

Returns the output values and loss if targets are provided.

@classmethod
def from_pretrained( cls, pretrained_model_name_or_path: str, tokenizer: nano_gpt.tokenizer.Tokenizer, model_config: nano_gpt.config.GPTConfig, **kwargs: Any) -> GPT:
223    @classmethod
224    def from_pretrained(
225        cls,
226        pretrained_model_name_or_path: str,
227        tokenizer: Tokenizer,
228        model_config: GPTConfig,
229        **kwargs: Any,
230    ) -> "GPT":
231        """Load the GPT from the pretrained model."""
232        model = GPT(model_config, tokenizer=tokenizer)
233        sd = model.state_dict()
234        sd_keys = [
235            k for k in sd.keys() if not k.endswith(".attn.bias")
236        ]  # discard this mask / buffer, not a param
237
238        # init a huggingface/transformers model
239        model_hf = GPT2LMHeadModel.from_pretrained(
240            pretrained_model_name_or_path, **kwargs
241        )
242        sd_hf = model_hf.state_dict()
243
244        # copy while ensuring all of the parameters are aligned and match in names and shapes
245        sd_keys_hf = [
246            k for k in sd_hf.keys() if not k.endswith(".attn.masked_bias")
247        ]  # ignore these, just a buffer
248        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.bias")]
249        # Transpose weights
250        assert len(sd_keys_hf) == len(sd_keys), (
251            f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
252        )
253        for k in sd_keys_hf:
254            if any(k.endswith(w) for w in PRETRAINED_TRANSPOSED_WEIGHTS):
255                # special treatment for the Conv1D weights we need to transpose
256                if sd_hf[k].shape[0] != sd[k].shape[1]:
257                    raise ValueError(
258                        f"mismatched shapes: {sd_hf[k].shape[::-1]} != {sd[k].shape} for key {k}"
259                    )
260                with torch.no_grad():
261                    sd[k].copy_(sd_hf[k].t())
262            else:
263                # vanilla copy over the other parameters
264                if sd_hf[k].shape != sd[k].shape:
265                    raise ValueError(
266                        f"mismatched shapes: {sd_hf[k].shape} != {sd[k].shape} for key {k}"
267                    )
268                with torch.no_grad():
269                    sd[k].copy_(sd_hf[k])
270
271        return model

Load the GPT from the pretrained model.

def sample( model: torch.nn.modules.module.Module, tokenizer: nano_gpt.tokenizer.Tokenizer, text: str, num_return_sequences: int, max_length: int, device: str, seed: int = 42) -> list[str]:
274def sample(
275    model: nn.Module,
276    tokenizer: Tokenizer,
277    text: str,
278    num_return_sequences: int,
279    max_length: int,
280    device: str,
281    seed: int = 42,
282) -> list[str]:
283    """Sample from the model from text input."""
284    tokenized_text = tokenizer.encode(text)
285    tokens = torch.tensor(tokenized_text, dtype=torch.long)  # (8, )
286    # Replicate input tokens
287    tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
288
289    # x is (B, T)
290    x = tokens.to(device)
291
292    sample_rng = torch.Generator(device=device)
293    sample_rng.manual_seed(seed)
294
295    # With each loop iteration we'll append a token to the sequence. This is
296    # adding one more column to x each time.
297    while x.size(1) < max_length:
298        logits, _ = model(x)  # (B, T, vocab_size)
299        # Take the logits at the last position (next character) and drop the others.
300        # This is correct but inefficient implementation of sampling.
301        # Question: What is T?
302        logits = logits[:, -1, :]  # (B, vocab_size)
303        probs = F.softmax(logits, dim=-1)
304        # Do top-k sampling of 50 which is the huggingface default. Get the top 50
305        # probabilities and set all other tokens to probability of zero. This helps
306        # keep the model on track so it doesn't go off the rails as easily.
307        # Both are (5, 50)
308        topk_probs, topk_indicies = torch.topk(probs, 50, dim=-1)
309        # Select a token from the top 5
310        ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # (B, 1)
311        # Gather corresponding indicidies
312        xcol = torch.gather(topk_indicies, -1, ix)
313        # Append the new character to the sequence (one for each in the batch)
314        x = torch.cat((x, xcol), dim=-1)
315
316    samples = []
317    for i in range(num_return_sequences):
318        seq_tokens = x[i, :max_length].tolist()
319        decoded = tokenizer.decode(seq_tokens)
320        samples.append(decoded)
321
322    return samples

Sample from the model from text input.