nano_gpt.model
This module defines the GPT model architecture.
This is a thin wrapper around the HuggingFace transformers library and uses the approach from the GPT-2/GPT-3 papers.
"""This module defines the GPT model architecture.

This is a thin wrapper around the HuggingFace transformers library and uses
the approach from the GPT-2/GPT-3 papers.
"""

import logging
from typing import cast, Any

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import GPT2LMHeadModel

from .tokenizer import Tokenizer
from .config import GPTConfig

_LOGGER = logging.getLogger(__name__)

__all__ = [
    "GPT",
    "sample",
]

# The openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
# this means that we have to transpose these weights when we import them
PRETRAINED_TRANSPOSED_WEIGHTS = [
    "attn.c_attn.weight",
    "attn.c_proj.weight",
    "mlp.c_fc.weight",
    "mlp.c_proj.weight",
]


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention."""

    def __init__(self, config: GPTConfig) -> None:
        """Initialize CausalSelfAttention.

        Args:
            config: Model hyperparameters (embedding size, head count, block size).
        """
        super().__init__()
        # Batch of key/query/value projections for all heads in one matmul
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # Marker consumed by GPT._init_weights: residual projections get their
        # init std scaled down by 1/sqrt(2 * n_layer).
        self.c_proj.SCALE_INIT = 1  # type: ignore[assignment]
        # Regularization
        self.n_head = config.n_head
        self.n_embed = config.n_embd
        # Lower-triangular causal mask. NOTE(review): unused at runtime because
        # forward() relies on scaled_dot_product_attention(is_causal=True); kept
        # as a buffer so checkpoints keep a stable key set (it is filtered out
        # by name in GPT.from_pretrained).
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform inference.

        Args:
            x: Input activations of shape (B, T, C) where C == n_embd.

        Returns:
            Attention output of shape (B, T, C).
        """
        B, T, C = x.size()
        # Compute the query, key, value for all heads in the batch.
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=2)
        # Each are (B, nh, T, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Compute attention with a fused kernel of fast attention
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Reassemble and concat everything
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection
        y = self.c_proj(y)
        return y  # type: ignore[no-any-return]


class MLP(nn.Module):
    """Multi-layer perceptron: Linear -> GELU -> Linear with a 4x hidden width."""

    def __init__(self, config: GPTConfig) -> None:
        """Initialize MLP.

        Args:
            config: Model hyperparameters (n_embd sets layer widths).
        """
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        # tanh approximation matches the historical GPT-2 GELU
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        # Residual-projection init scaling marker (see GPT._init_weights)
        self.c_proj.SCALE_INIT = 1  # type: ignore[assignment]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform inference on a (B, T, n_embd) tensor; shape is preserved."""
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    """A transformer block: pre-norm attention and MLP, each with a residual."""

    def __init__(self, config: GPTConfig) -> None:
        """Initialize Block.

        Args:
            config: Model hyperparameters, forwarded to the sub-modules.
        """
        super().__init__()
        self.config = config
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform inference; input and output are both (B, T, n_embd)."""
        # Pre-norm residual connections (GPT-2 style)
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT(nn.Module):
    """This class defines the GPT model."""

    def __init__(self, config: GPTConfig, tokenizer: Tokenizer) -> None:
        """Initialize GPT.

        Args:
            config: Model hyperparameters.
            tokenizer: Tokenizer stored on the model for convenience.
        """
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(config.vocab_size, config.n_embd),
                "wpe": nn.Embedding(config.block_size, config.n_embd),
                "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                "ln_f": nn.LayerNorm(config.n_embd),
            }
        )
        # Final classifier
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.enc = tokenizer

        # Share weights for input and output embeddings. This is about 30% of
        # the model weights.
        self.transformer.wte.weight = self.lm_head.weight  # type: ignore[union-attr]
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Perform additional weight initialization to match gpt-2."""
        std = 0.02
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                # GPT-2 scales residual-projection init by 1/sqrt(2 * n_layer)
                # (two residual additions per block keep activation variance
                # stable with depth). BUGFIX: exponent was -0.05 instead of -0.5.
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=std)

    def configure_optimizers(
        self,
        weight_decay: float,
        learning_rate: float,
        use_fused: bool,
    ) -> torch.optim.AdamW:
        """Return the optimizer.

        Args:
            weight_decay: Weight decay applied to >=2-dim parameters (matmul
                weights and embeddings); biases and layernorm params are exempt.
            learning_rate: Initial learning rate for AdamW.
            use_fused: Whether to use the fused AdamW kernel.
        """
        # start with all params that require grad
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim optim_groups
        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
        nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        _LOGGER.info(
            "num decay_params %s (tensors) / %s (parameters)",
            len(decay_params),
            num_decay_params,
        )
        _LOGGER.info(
            "num nodecay_params %s (tensors) / %s (parameters)",
            len(nodecay_params),
            num_nodecay_params,
        )
        _LOGGER.info("Using fused adamw : %s", use_fused)
        # BUGFIX: previously hard-coded lr=3e-4, silently ignoring the
        # learning_rate argument.
        return torch.optim.AdamW(
            params=optim_groups,
            lr=learning_rate,
            betas=(0.9, 0.95),
            eps=1e-8,
            fused=use_fused,
        )

    def forward(
        self, x: torch.Tensor, targets: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Perform the forward pass.

        Returns the output values and loss if targets are provided.

        Args:
            x: Token ids of shape (B, T) with T <= config.block_size.
            targets: Optional next-token targets of shape (B, T).

        Returns:
            A (logits, loss) tuple; logits is (B, T, vocab_size) and loss is
            None unless targets were given.

        Raises:
            ValueError: If the sequence length exceeds the block size.
        """
        B, T = x.size()
        if T > self.config.block_size:
            raise ValueError(
                f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
            )
        # Forward token and positional embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=x.device)  # Shape (T)

        # (T, n_emb)
        pos_emb = cast(nn.Embedding, self.transformer.wpe)(pos)
        # (B, T, n_emb)
        tok_emb = cast(nn.Embedding, self.transformer.wte)(x)
        # Position embeddings broadcast over the batch dimension
        x = tok_emb + pos_emb
        # Forward transformer blocks
        for block in cast(nn.ModuleList, self.transformer.h):
            x = block(x)
        # Forward the final layernorm
        x = cast(nn.LayerNorm, self.transformer.ln_f)(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                # Flatten to (BxT, vocab_size)
                logits.view(-1, logits.size(-1)),
                # Flatten to (BxT)
                targets.view(-1),
            )
        return logits, loss

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        tokenizer: Tokenizer,
        model_config: GPTConfig,
        **kwargs: Any,
    ) -> "GPT":
        """Load the GPT from the pretrained model.

        Args:
            pretrained_model_name_or_path: Checkpoint name/path passed to
                GPT2LMHeadModel.from_pretrained.
            tokenizer: Tokenizer attached to the returned model.
            model_config: Architecture config; must match the checkpoint.
            **kwargs: Extra arguments forwarded to the HuggingFace loader.

        Returns:
            A GPT instance with weights copied from the HuggingFace checkpoint.

        Raises:
            ValueError: If a checkpoint tensor's shape does not match ours.
        """
        model = GPT(model_config, tokenizer=tokenizer)
        sd = model.state_dict()
        sd_keys = [
            k for k in sd.keys() if not k.endswith(".attn.bias")
        ]  # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = [
            k for k in sd_hf.keys() if not k.endswith(".attn.masked_bias")
        ]  # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.bias")]
        # Transpose weights
        assert len(sd_keys_hf) == len(sd_keys), (
            f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        )
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in PRETRAINED_TRANSPOSED_WEIGHTS):
                # special treatment for the Conv1D weights we need to transpose
                if sd_hf[k].shape[0] != sd[k].shape[1]:
                    raise ValueError(
                        f"mismatched shapes: {sd_hf[k].shape[::-1]} != {sd[k].shape} for key {k}"
                    )
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                if sd_hf[k].shape != sd[k].shape:
                    raise ValueError(
                        f"mismatched shapes: {sd_hf[k].shape} != {sd[k].shape} for key {k}"
                    )
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model


def sample(
    model: nn.Module,
    tokenizer: Tokenizer,
    text: str,
    num_return_sequences: int,
    max_length: int,
    device: str,
    seed: int = 42,
) -> list[str]:
    """Sample from the model from text input.

    Args:
        model: A model whose forward returns (logits, loss) like GPT.forward.
        tokenizer: Used to encode the prompt and decode the samples.
        text: Prompt text shared by all returned sequences.
        num_return_sequences: Number of independent continuations to sample.
        max_length: Total token length (prompt included) of each sequence.
        device: Device string for the input tensor and the sampling RNG.
        seed: Seed for the dedicated sampling generator (deterministic output).

    Returns:
        The decoded sequences, one string per returned sample.
    """
    tokenized_text = tokenizer.encode(text)
    tokens = torch.tensor(tokenized_text, dtype=torch.long)  # (prompt_len,)
    # Replicate input tokens so each returned sequence starts from the prompt
    tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)

    # x is (B, T)
    x = tokens.to(device)

    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(seed)

    # With each loop iteration we'll append a token to the sequence. This is
    # adding one more column to x each time.
    while x.size(1) < max_length:
        logits, _ = model(x)  # (B, T, vocab_size)
        # Take the logits at the last position (next token) and drop the
        # others. Correct but inefficient: the full prefix is re-encoded on
        # every step (no KV cache).
        logits = logits[:, -1, :]  # (B, vocab_size)
        probs = F.softmax(logits, dim=-1)
        # Do top-k sampling of 50 which is the huggingface default. Get the top 50
        # probabilities and set all other tokens to probability of zero. This helps
        # keep the model on track so it doesn't go off the rails as easily.
        # Both are (B, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        # Select a token position from the top 50
        ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # (B, 1)
        # Gather the corresponding vocabulary indices
        xcol = torch.gather(topk_indices, -1, ix)
        # Append the new token to the sequence (one for each in the batch)
        x = torch.cat((x, xcol), dim=-1)

    samples = []
    for i in range(num_return_sequences):
        seq_tokens = x[i, :max_length].tolist()
        decoded = tokenizer.decode(seq_tokens)
        samples.append(decoded)

    return samples
class
GPT(torch.nn.modules.module.Module):
117class GPT(nn.Module): 118 """This class defines the GPT model.""" 119 120 def __init__(self, config: GPTConfig, tokenizer: Tokenizer) -> None: 121 super().__init__() 122 self.config = config 123 124 self.transformer = nn.ModuleDict( 125 { 126 "wte": nn.Embedding(config.vocab_size, config.n_embd), 127 "wpe": nn.Embedding(config.block_size, config.n_embd), 128 "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 129 "ln_f": nn.LayerNorm(config.n_embd), 130 } 131 ) 132 # Final classifier 133 self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 134 self.enc = tokenizer 135 136 # Share weights for input and output embeddings. This is about 30% of 137 # the model weights. 138 self.transformer.wte.weight = self.lm_head.weight # type: ignore[union-attr] 139 self.apply(self._init_weights) 140 141 def _init_weights(self, module: nn.Module) -> None: 142 """Perform additional weight initialization to match gpt-2.""" 143 std = 0.02 144 if isinstance(module, nn.Linear): 145 if hasattr(module, "SCALE_INIT"): 146 std *= (2 * self.config.n_layer) ** -0.05 147 torch.nn.init.normal_(module.weight, mean=0, std=std) 148 if module.bias is not None: 149 torch.nn.init.zeros_(module.bias) 150 elif isinstance(module, nn.Embedding): 151 torch.nn.init.normal_(module.weight, mean=0, std=std) 152 153 def configure_optimizers( 154 self, 155 weight_decay: float, 156 learning_rate: float, 157 use_fused: bool, 158 ) -> torch.optim.AdamW: 159 """Return the optimizer.""" 160 # start with all params that require grad 161 param_dict = {pn: p for pn, p in self.named_parameters()} 162 param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 163 # create optim optim_groups 164 decay_params = [p for p in param_dict.values() if p.dim() >= 2] 165 nodecay_params = [p for p in param_dict.values() if p.dim() < 2] 166 optim_groups = [ 167 {"params": decay_params, "weight_decay": weight_decay}, 168 {"params": nodecay_params, "weight_decay": 0}, 169 ] 170 
num_decay_params = sum(p.numel() for p in decay_params) 171 num_nodecay_params = sum(p.numel() for p in nodecay_params) 172 _LOGGER.info( 173 "num decay_params %s (tensors) / %s (parameters)", 174 len(decay_params), 175 num_decay_params, 176 ) 177 _LOGGER.info( 178 "num nodecay_params %s (tensors) / %s (parameters)", 179 len(nodecay_params), 180 num_nodecay_params, 181 ) 182 _LOGGER.info("Using fused adamw : %s", use_fused) 183 return torch.optim.AdamW( 184 params=optim_groups, lr=3e-4, betas=(0.9, 0.95), eps=1e-8, fused=use_fused 185 ) 186 187 def forward( 188 self, x: torch.Tensor, targets: torch.Tensor | None = None 189 ) -> tuple[torch.Tensor, torch.Tensor | None]: 190 """Perform the forward pass. 191 192 Returns the output values and loss if targets are provided. 193 """ 194 B, T = x.size() 195 if T > self.config.block_size: 196 raise ValueError( 197 f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}" 198 ) 199 # Forward token and positional embeddings 200 pos = torch.arange(0, T, dtype=torch.long, device=x.device) # Shape (T) 201 202 # (T, n_emb) 203 pos_emb = cast(nn.Embedding, self.transformer.wpe)(pos) 204 # (B, T, n_emb) 205 tok_emb = cast(nn.Embedding, self.transformer.wte)(x) 206 x = tok_emb + pos_emb 207 # Forward transformer blocks 208 for block in cast(nn.ModuleList, self.transformer.h): 209 x = block(x) 210 # Forward the final layernorm 211 x = cast(nn.LayerNorm, self.transformer.ln_f)(x) 212 logits = self.lm_head(x) # (B, T, vocab_size) 213 loss = None 214 if targets is not None: 215 loss = F.cross_entropy( 216 # Flatten to (BxT, vocab_size) 217 logits.view(-1, logits.size(-1)), 218 # Flatten to (BxT) 219 targets.view(-1), 220 ) 221 return logits, loss 222 223 @classmethod 224 def from_pretrained( 225 cls, 226 pretrained_model_name_or_path: str, 227 tokenizer: Tokenizer, 228 model_config: GPTConfig, 229 **kwargs: Any, 230 ) -> "GPT": 231 """Load the GPT from the pretrained model.""" 232 model = GPT(model_config, 
tokenizer=tokenizer) 233 sd = model.state_dict() 234 sd_keys = [ 235 k for k in sd.keys() if not k.endswith(".attn.bias") 236 ] # discard this mask / buffer, not a param 237 238 # init a huggingface/transformers model 239 model_hf = GPT2LMHeadModel.from_pretrained( 240 pretrained_model_name_or_path, **kwargs 241 ) 242 sd_hf = model_hf.state_dict() 243 244 # copy while ensuring all of the parameters are aligned and match in names and shapes 245 sd_keys_hf = [ 246 k for k in sd_hf.keys() if not k.endswith(".attn.masked_bias") 247 ] # ignore these, just a buffer 248 sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.bias")] 249 # Transpose weights 250 assert len(sd_keys_hf) == len(sd_keys), ( 251 f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 252 ) 253 for k in sd_keys_hf: 254 if any(k.endswith(w) for w in PRETRAINED_TRANSPOSED_WEIGHTS): 255 # special treatment for the Conv1D weights we need to transpose 256 if sd_hf[k].shape[0] != sd[k].shape[1]: 257 raise ValueError( 258 f"mismatched shapes: {sd_hf[k].shape[::-1]} != {sd[k].shape} for key {k}" 259 ) 260 with torch.no_grad(): 261 sd[k].copy_(sd_hf[k].t()) 262 else: 263 # vanilla copy over the other parameters 264 if sd_hf[k].shape != sd[k].shape: 265 raise ValueError( 266 f"mismatched shapes: {sd_hf[k].shape} != {sd[k].shape} for key {k}" 267 ) 268 with torch.no_grad(): 269 sd[k].copy_(sd_hf[k]) 270 271 return model
This class defines the GPT model.
GPT( config: nano_gpt.config.GPTConfig, tokenizer: nano_gpt.tokenizer.Tokenizer)
120 def __init__(self, config: GPTConfig, tokenizer: Tokenizer) -> None: 121 super().__init__() 122 self.config = config 123 124 self.transformer = nn.ModuleDict( 125 { 126 "wte": nn.Embedding(config.vocab_size, config.n_embd), 127 "wpe": nn.Embedding(config.block_size, config.n_embd), 128 "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 129 "ln_f": nn.LayerNorm(config.n_embd), 130 } 131 ) 132 # Final classifier 133 self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 134 self.enc = tokenizer 135 136 # Share weights for input and output embeddings. This is about 30% of 137 # the model weights. 138 self.transformer.wte.weight = self.lm_head.weight # type: ignore[union-attr] 139 self.apply(self._init_weights)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def
configure_optimizers( self, weight_decay: float, learning_rate: float, use_fused: bool) -> torch.optim.adamw.AdamW:
153 def configure_optimizers( 154 self, 155 weight_decay: float, 156 learning_rate: float, 157 use_fused: bool, 158 ) -> torch.optim.AdamW: 159 """Return the optimizer.""" 160 # start with all params that require grad 161 param_dict = {pn: p for pn, p in self.named_parameters()} 162 param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 163 # create optim optim_groups 164 decay_params = [p for p in param_dict.values() if p.dim() >= 2] 165 nodecay_params = [p for p in param_dict.values() if p.dim() < 2] 166 optim_groups = [ 167 {"params": decay_params, "weight_decay": weight_decay}, 168 {"params": nodecay_params, "weight_decay": 0}, 169 ] 170 num_decay_params = sum(p.numel() for p in decay_params) 171 num_nodecay_params = sum(p.numel() for p in nodecay_params) 172 _LOGGER.info( 173 "num decay_params %s (tensors) / %s (parameters)", 174 len(decay_params), 175 num_decay_params, 176 ) 177 _LOGGER.info( 178 "num nodecay_params %s (tensors) / %s (parameters)", 179 len(nodecay_params), 180 num_nodecay_params, 181 ) 182 _LOGGER.info("Using fused adamw : %s", use_fused) 183 return torch.optim.AdamW( 184 params=optim_groups, lr=3e-4, betas=(0.9, 0.95), eps=1e-8, fused=use_fused 185 )
Return the optimizer.
def
forward( self, x: torch.Tensor, targets: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor | None]:
187 def forward( 188 self, x: torch.Tensor, targets: torch.Tensor | None = None 189 ) -> tuple[torch.Tensor, torch.Tensor | None]: 190 """Perform the forward pass. 191 192 Returns the output values and loss if targets are provided. 193 """ 194 B, T = x.size() 195 if T > self.config.block_size: 196 raise ValueError( 197 f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}" 198 ) 199 # Forward token and positional embeddings 200 pos = torch.arange(0, T, dtype=torch.long, device=x.device) # Shape (T) 201 202 # (T, n_emb) 203 pos_emb = cast(nn.Embedding, self.transformer.wpe)(pos) 204 # (B, T, n_emb) 205 tok_emb = cast(nn.Embedding, self.transformer.wte)(x) 206 x = tok_emb + pos_emb 207 # Forward transformer blocks 208 for block in cast(nn.ModuleList, self.transformer.h): 209 x = block(x) 210 # Forward the final layernorm 211 x = cast(nn.LayerNorm, self.transformer.ln_f)(x) 212 logits = self.lm_head(x) # (B, T, vocab_size) 213 loss = None 214 if targets is not None: 215 loss = F.cross_entropy( 216 # Flatten to (BxT, vocab_size) 217 logits.view(-1, logits.size(-1)), 218 # Flatten to (BxT) 219 targets.view(-1), 220 ) 221 return logits, loss
Perform the forward pass.
Returns the output values and loss if targets are provided.
@classmethod
def
from_pretrained( cls, pretrained_model_name_or_path: str, tokenizer: nano_gpt.tokenizer.Tokenizer, model_config: nano_gpt.config.GPTConfig, **kwargs: Any) -> GPT:
223 @classmethod 224 def from_pretrained( 225 cls, 226 pretrained_model_name_or_path: str, 227 tokenizer: Tokenizer, 228 model_config: GPTConfig, 229 **kwargs: Any, 230 ) -> "GPT": 231 """Load the GPT from the pretrained model.""" 232 model = GPT(model_config, tokenizer=tokenizer) 233 sd = model.state_dict() 234 sd_keys = [ 235 k for k in sd.keys() if not k.endswith(".attn.bias") 236 ] # discard this mask / buffer, not a param 237 238 # init a huggingface/transformers model 239 model_hf = GPT2LMHeadModel.from_pretrained( 240 pretrained_model_name_or_path, **kwargs 241 ) 242 sd_hf = model_hf.state_dict() 243 244 # copy while ensuring all of the parameters are aligned and match in names and shapes 245 sd_keys_hf = [ 246 k for k in sd_hf.keys() if not k.endswith(".attn.masked_bias") 247 ] # ignore these, just a buffer 248 sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.bias")] 249 # Transpose weights 250 assert len(sd_keys_hf) == len(sd_keys), ( 251 f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 252 ) 253 for k in sd_keys_hf: 254 if any(k.endswith(w) for w in PRETRAINED_TRANSPOSED_WEIGHTS): 255 # special treatment for the Conv1D weights we need to transpose 256 if sd_hf[k].shape[0] != sd[k].shape[1]: 257 raise ValueError( 258 f"mismatched shapes: {sd_hf[k].shape[::-1]} != {sd[k].shape} for key {k}" 259 ) 260 with torch.no_grad(): 261 sd[k].copy_(sd_hf[k].t()) 262 else: 263 # vanilla copy over the other parameters 264 if sd_hf[k].shape != sd[k].shape: 265 raise ValueError( 266 f"mismatched shapes: {sd_hf[k].shape} != {sd[k].shape} for key {k}" 267 ) 268 with torch.no_grad(): 269 sd[k].copy_(sd_hf[k]) 270 271 return model
Load the GPT from the pretrained model.
def
sample( model: torch.nn.modules.module.Module, tokenizer: nano_gpt.tokenizer.Tokenizer, text: str, num_return_sequences: int, max_length: int, device: str, seed: int = 42) -> list[str]:
274def sample( 275 model: nn.Module, 276 tokenizer: Tokenizer, 277 text: str, 278 num_return_sequences: int, 279 max_length: int, 280 device: str, 281 seed: int = 42, 282) -> list[str]: 283 """Sample from the model from text input.""" 284 tokenized_text = tokenizer.encode(text) 285 tokens = torch.tensor(tokenized_text, dtype=torch.long) # (8, ) 286 # Replicate input tokens 287 tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1) 288 289 # x is (B, T) 290 x = tokens.to(device) 291 292 sample_rng = torch.Generator(device=device) 293 sample_rng.manual_seed(seed) 294 295 # With each loop iteration we'll append a token to the sequence. This is 296 # adding one more column to x each time. 297 while x.size(1) < max_length: 298 logits, _ = model(x) # (B, T, vocab_size) 299 # Take the logits at the last position (next character) and drop the others. 300 # This is correct but inefficient implementation of sampling. 301 # Question: What is T? 302 logits = logits[:, -1, :] # (B, vocab_size) 303 probs = F.softmax(logits, dim=-1) 304 # Do top-k sampling of 50 which is the huggingface default. Get the top 50 305 # probabilities and set all other tokens to probability of zero. This helps 306 # keep the model on track so it doesn't go off the rails as easily. 307 # Both are (5, 50) 308 topk_probs, topk_indicies = torch.topk(probs, 50, dim=-1) 309 # Select a token from the top 5 310 ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1) 311 # Gather corresponding indicidies 312 xcol = torch.gather(topk_indicies, -1, ix) 313 # Append the new character to the sequence (one for each in the batch) 314 x = torch.cat((x, xcol), dim=-1) 315 316 samples = [] 317 for i in range(num_return_sequences): 318 seq_tokens = x[i, :max_length].tolist() 319 decoded = tokenizer.decode(seq_tokens) 320 samples.append(decoded) 321 322 return samples
Sample from the model from text input.