From 609b216529f78308665effbb79f7fdae56d4558f Mon Sep 17 00:00:00 2001
From: greygame <125168862@qq.com>
Date: Thu, 9 Jan 2025 16:11:09 +0800
Subject: [PATCH 1/3] Add rotary position embeddings (RoPE), drop learned wpe

---
 model.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 55 insertions(+), 10 deletions(-)

diff --git a/model.py b/model.py
index c698f8b601..383ee9de6e 100644
--- a/model.py
+++ b/model.py
@@ -10,11 +10,38 @@
 import math
 import inspect
 from dataclasses import dataclass
+from typing import Tuple
 
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
 
+
+def precompute_freqs_cis(dim: int, length: int, theta: float = 10000.0):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+    position_ids = torch.arange(length, dtype=torch.float32)
+    freqs = torch.outer(position_ids, freqs).unsqueeze(0)
+    emb = torch.cat((freqs, freqs), dim=-1)
+    return emb.cos(), emb.sin() # each of shape (1, length, dim)
+
+
+def rotate_half(x):
+    # x (B, nh, T, hs)
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_emb(q, k, cos, sin):
+    device = q.device
+
+    cos = cos.unsqueeze(1).to(device)
+    sin = sin.unsqueeze(1).to(device)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 class LayerNorm(nn.Module):
     """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
 
@@ -41,15 +68,16 @@ def __init__(self, config):
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.dropout = config.dropout
+        self.position_rope = config.position_rope
         # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
         if not self.flash:
             print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
-                                        .view(1, 1, config.block_size, config.block_size))
+                                .view(1, 1, config.block_size, config.block_size))
 
-    def forward(self, x):
+    def forward(self, x, position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
 
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -57,6 +85,9 @@
         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        if self.position_rope:
+            cos, sin = position_embeddings
+            q, k = apply_rotary_emb(q, k, cos[:, :T], sin[:, :T]) # slice cached tables to the current sequence length
 
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
@@ -100,8 +131,8 @@ def __init__(self, config):
         self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
         self.mlp = MLP(config)
 
-    def forward(self, x):
-        x = x + self.attn(self.ln_1(x))
+    def forward(self, x, position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None):
+        x = x + self.attn(self.ln_1(x), position_embeddings)
         x = x + self.mlp(self.ln_2(x))
         return x
 
@@ -114,6 +145,8 @@ class GPTConfig:
     n_embd: int = 768
     dropout: float = 0.0
     bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+    position_rope: bool = True # True: rotary position embeddings (RoPE), False: learned absolute position embeddings
+    rope_theta: float = 10000.0 # base frequency used by RoPE
 
 class GPT(nn.Module):
 
@@ -125,11 +158,19 @@ def __init__(self, config):
 
         self.transformer = nn.ModuleDict(dict(
             wte = nn.Embedding(config.vocab_size, config.n_embd),
-            wpe = nn.Embedding(config.block_size, config.n_embd),
             drop = nn.Dropout(config.dropout),
             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
             ln_f = LayerNorm(config.n_embd, bias=config.bias),
         ))
+
+        if not config.position_rope:
+            self.transformer.add_module('wpe', nn.Embedding(config.block_size, config.n_embd))
+            self.position_embeddings = None # unused when learned position embeddings are active
+        else:
+            self.position_embeddings = precompute_freqs_cis(config.n_embd // config.n_head,
+                                                            config.block_size,
+                                                            config.rope_theta)
+
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         # with weight tying when using torch.compile() some warnings get generated:
         # "UserWarning: functional_call was passed multiple values for tied weights.
@@ -145,7 +186,7 @@ def __init__(self, config):
                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
 
         # report number of parameters
-        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
+        print("number of parameters: %.2fM" % (self.get_num_params(False)/1e6,))
 
     def get_num_params(self, non_embedding=True):
         """
@@ -156,7 +197,7 @@ def get_num_params(self, non_embedding=True):
         """
         n_params = sum(p.numel() for p in self.parameters())
         if non_embedding:
-            n_params -= self.transformer.wpe.weight.numel()
+            n_params -= self.transformer['wpe'].weight.numel()
         return n_params
 
     def _init_weights(self, module):
@@ -175,10 +216,12 @@ def forward(self, idx, targets=None):
 
         # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-        x = self.transformer.drop(tok_emb + pos_emb)
+        if not self.config.position_rope:
+            pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+            tok_emb = tok_emb + pos_emb
+        x = self.transformer.drop(tok_emb)
         for block in self.transformer.h:
-            x = block(x)
+            x = block(x, self.position_embeddings)
         x = self.transformer.ln_f(x)
 
         if targets is not None:
@@ -328,3 +371,5 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
             idx = torch.cat((idx, idx_next), dim=1)
 
         return idx
+
+#%%

From 2fbb7eff36243b64121b19df992d686535aa1055 Mon Sep 17 00:00:00 2001
From: greygame <125168862@qq.com>
Date: Thu, 9 Jan 2025 17:53:14 +0800
Subject: [PATCH 2/3] Guard get_num_params when wpe is absent

---
 model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model.py b/model.py
index 383ee9de6e..691e43fe54 100644
--- a/model.py
+++ b/model.py
@@ -196,7 +196,7 @@ def get_num_params(self, non_embedding=True):
         params are actually used as weights in the final layer, so we include them.
         """
         n_params = sum(p.numel() for p in self.parameters())
-        if non_embedding:
+        if non_embedding and 'wpe' in self.transformer:
             n_params -= self.transformer['wpe'].weight.numel()
         return n_params
 

From 1c5f4bd1cdf3c84a3383ca0472c5c2b323ca82fd Mon Sep 17 00:00:00 2001
From: greygame <125168862@qq.com>
Date: Thu, 9 Jan 2025 18:24:44 +0800
Subject: [PATCH 3/3] Add RoPE (rotary position embeddings)

---
 model.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 56 insertions(+), 11 deletions(-)

diff --git a/model.py b/model.py
index c698f8b601..691e43fe54 100644
--- a/model.py
+++ b/model.py
@@ -10,11 +10,38 @@
 import math
 import inspect
 from dataclasses import dataclass
+from typing import Tuple
 
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
 
+
+def precompute_freqs_cis(dim: int, length: int, theta: float = 10000.0):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+    position_ids = torch.arange(length, dtype=torch.float32)
+    freqs = torch.outer(position_ids, freqs).unsqueeze(0)
+    emb = torch.cat((freqs, freqs), dim=-1)
+    return emb.cos(), emb.sin() # each of shape (1, length, dim)
+
+
+def rotate_half(x):
+    # x (B, nh, T, hs)
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_emb(q, k, cos, sin):
+    device = q.device
+
+    cos = cos.unsqueeze(1).to(device)
+    sin = sin.unsqueeze(1).to(device)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 class LayerNorm(nn.Module):
     """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
 
@@ -41,15 +68,16 @@ def __init__(self, config):
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.dropout = config.dropout
+        self.position_rope = config.position_rope
         # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
         if not self.flash:
             print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
-                                        .view(1, 1, config.block_size, config.block_size))
+                                .view(1, 1, config.block_size, config.block_size))
 
-    def forward(self, x):
+    def forward(self, x, position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
 
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -57,6 +85,9 @@
         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        if self.position_rope:
+            cos, sin = position_embeddings
+            q, k = apply_rotary_emb(q, k, cos[:, :T], sin[:, :T]) # slice cached tables to the current sequence length
 
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
@@ -100,8 +131,8 @@ def __init__(self, config):
         self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
         self.mlp = MLP(config)
 
-    def forward(self, x):
-        x = x + self.attn(self.ln_1(x))
+    def forward(self, x, position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None):
+        x = x + self.attn(self.ln_1(x), position_embeddings)
         x = x + self.mlp(self.ln_2(x))
         return x
 
@@ -114,6 +145,8 @@ class GPTConfig:
     n_embd: int = 768
     dropout: float = 0.0
     bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+    position_rope: bool = True # True: rotary position embeddings (RoPE), False: learned absolute position embeddings
+    rope_theta: float = 10000.0 # base frequency used by RoPE
 
 class GPT(nn.Module):
 
@@ -125,11 +158,19 @@ def __init__(self, config):
 
         self.transformer = nn.ModuleDict(dict(
             wte = nn.Embedding(config.vocab_size, config.n_embd),
-            wpe = nn.Embedding(config.block_size, config.n_embd),
             drop = nn.Dropout(config.dropout),
             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
             ln_f = LayerNorm(config.n_embd, bias=config.bias),
         ))
+
+        if not config.position_rope:
+            self.transformer.add_module('wpe', nn.Embedding(config.block_size, config.n_embd))
+            self.position_embeddings = None # unused when learned position embeddings are active
+        else:
+            self.position_embeddings = precompute_freqs_cis(config.n_embd // config.n_head,
+                                                            config.block_size,
+                                                            config.rope_theta)
+
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         # with weight tying when using torch.compile() some warnings get generated:
         # "UserWarning: functional_call was passed multiple values for tied weights.
@@ -145,7 +186,7 @@ def __init__(self, config):
                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
 
         # report number of parameters
-        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
+        print("number of parameters: %.2fM" % (self.get_num_params(False)/1e6,))
 
     def get_num_params(self, non_embedding=True):
         """
@@ -155,8 +196,8 @@ def get_num_params(self, non_embedding=True):
         params are actually used as weights in the final layer, so we include them.
         """
         n_params = sum(p.numel() for p in self.parameters())
-        if non_embedding:
-            n_params -= self.transformer.wpe.weight.numel()
+        if non_embedding and 'wpe' in self.transformer:
+            n_params -= self.transformer['wpe'].weight.numel()
         return n_params
 
     def _init_weights(self, module):
@@ -175,10 +216,12 @@ def forward(self, idx, targets=None):
 
         # forward the GPT model itself
         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-        x = self.transformer.drop(tok_emb + pos_emb)
+        if not self.config.position_rope:
+            pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+            tok_emb = tok_emb + pos_emb
+        x = self.transformer.drop(tok_emb)
         for block in self.transformer.h:
-            x = block(x)
+            x = block(x, self.position_embeddings)
         x = self.transformer.ln_f(x)
 
         if targets is not None:
@@ -328,3 +371,5 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
             idx = torch.cat((idx, idx_next), dim=1)
 
         return idx
+
+#%%