Add rotary #585

Open · wants to merge 4 commits into base: master
model.py: 67 changes (56 additions, 11 deletions)
@@ -10,11 +10,38 @@
import math
import inspect
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F


def precompute_freqs_cis(dim: int, length: int, theta: float = 10000.0):
    # precompute rotary tables: returns (cos, sin), each of shape (1, length, dim)
    freqs = 1.0 / theta ** (torch.arange(0, dim, 2).float() / dim)
position_ids = torch.arange(length, dtype=torch.float32)
freqs = torch.outer(position_ids, freqs).unsqueeze(0)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos(), emb.sin()


def rotate_hale(x):
# x (B, nh, T, hs)
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb(q, k, cos, sin):
    # q, k: (B, nh, T, hs); cos, sin: (1, block_size, hs)
    T = q.shape[-2]
    device = q.device

    # slice the precomputed tables to the current sequence length so the shapes
    # broadcast for any T <= block_size (e.g. during generation)
    cos = cos[:, :T].unsqueeze(1).to(device)
    sin = sin[:, :T].unsqueeze(1).to(device)
    q_embed = (q * cos) + (rotate_hale(q) * sin)
    k_embed = (k * cos) + (rotate_hale(k) * sin)
    return q_embed, k_embed


class LayerNorm(nn.Module):
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

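As a quick sanity check of the helpers above, a minimal standalone sketch (not part of this diff; it assumes this branch's model.py is importable) that applies the rotary embedding to toy q/k tensors: shapes are preserved, and position 0 is passed through untouched because its rotation angle is zero.

```python
import torch
from model import precompute_freqs_cis, apply_rotary_emb  # assumes this branch's model.py

B, nh, T, hs = 2, 4, 8, 16                      # toy sizes; hs must be even
cos, sin = precompute_freqs_cis(hs, T)          # each of shape (1, T, hs)
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)
q_rot, k_rot = apply_rotary_emb(q, k, cos, sin)

assert q_rot.shape == q.shape and k_rot.shape == k.shape
# position 0 has angle 0 (cos = 1, sin = 0), so it is left unchanged
assert torch.allclose(q_rot[:, :, 0], q[:, :, 0])
```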
@@ -41,22 +68,26 @@ def __init__(self, config):
self.n_head = config.n_head
self.n_embd = config.n_embd
self.dropout = config.dropout
self.position_rope = config.position_rope
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

-    def forward(self, x):
+    def forward(self, x, position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
if self.position_rope:
cos, sin = position_embeddings
q, k = apply_rotary_emb(q, k, cos, sin)

# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
if self.flash:
@@ -100,8 +131,8 @@ def __init__(self, config):
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
self.mlp = MLP(config)

-    def forward(self, x):
-        x = x + self.attn(self.ln_1(x))
+    def forward(self, x, position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None):
+        x = x + self.attn(self.ln_1(x), position_embeddings)
x = x + self.mlp(self.ln_2(x))
return x

@@ -114,6 +145,8 @@ class GPTConfig:
n_embd: int = 768
dropout: float = 0.0
bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    position_rope: bool = True # True: rotary position embeddings (RoPE); False: learned absolute positions (wpe)
    rope_theta: float = 10000.0 # base frequency for the rotary embeddings

class GPT(nn.Module):

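A hypothetical usage sketch (not part of this diff; it assumes the two new options are proper dataclass fields as annotated above, and that this branch's model.py is importable): the flag selects precomputed rotary tables at init instead of a learned wpe table, and get_num_params then has no position embedding to subtract.

```python
import torch
from model import GPT, GPTConfig  # assumes this branch's model.py

common = dict(block_size=256, vocab_size=50304, n_layer=4, n_head=4, n_embd=128, dropout=0.0)
rope_model = GPT(GPTConfig(**common, position_rope=True, rope_theta=10000.0))
wpe_model = GPT(GPTConfig(**common, position_rope=False))

assert not hasattr(rope_model.transformer, 'wpe')                # rotary: no learned position table
assert isinstance(wpe_model.transformer.wpe, torch.nn.Embedding)

# with rotary embeddings there is no wpe to subtract, so both counts coincide
assert rope_model.get_num_params() == rope_model.get_num_params(non_embedding=False)
```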
@@ -125,11 +158,19 @@ def __init__(self, config):

self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
-            wpe = nn.Embedding(config.block_size, config.n_embd),
drop = nn.Dropout(config.dropout),
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f = LayerNorm(config.n_embd, bias=config.bias),
))

        if not config.position_rope:
            # learned absolute position embeddings, as in the original nanoGPT
            self.transformer.add_module('wpe', nn.Embedding(config.block_size, config.n_embd))
            self.position_embeddings = None
        else:
            # precomputed rotary cos/sin tables, one row per position up to block_size
            self.position_embeddings = precompute_freqs_cis(config.n_embd // config.n_head,
                                                            config.block_size,
                                                            config.rope_theta)

self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# with weight tying when using torch.compile() some warnings get generated:
# "UserWarning: functional_call was passed multiple values for tied weights.
@@ -145,7 +186,7 @@ def __init__(self, config):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

# report number of parameters
print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
print("number of parameters: %.2fM" % (self.get_num_params(False)/1e6,))

def get_num_params(self, non_embedding=True):
"""
@@ -155,8 +196,8 @@ def get_num_params(self, non_embedding=True):
params are actually used as weights in the final layer, so we include them.
"""
n_params = sum(p.numel() for p in self.parameters())
-        if non_embedding:
-            n_params -= self.transformer.wpe.weight.numel()
+        if non_embedding and 'wpe' in self.transformer:
+            n_params -= self.transformer['wpe'].weight.numel()
return n_params

def _init_weights(self, module):
@@ -175,10 +216,12 @@ def forward(self, idx, targets=None):

# forward the GPT model itself
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-        x = self.transformer.drop(tok_emb + pos_emb)
+        if not self.config.position_rope:
+            pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+            tok_emb = tok_emb + pos_emb
+        x = self.transformer.drop(tok_emb)
for block in self.transformer.h:
-            x = block(x)
+            x = block(x, self.position_embeddings)
x = self.transformer.ln_f(x)

if targets is not None:
@@ -328,3 +371,5 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
idx = torch.cat((idx, idx_next), dim=1)

return idx

#%%
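Finally, a standalone property check (again not part of this diff, assuming the helpers are importable from model.py): with the same content vector at every position, the post-rotation attention scores should depend only on the relative offset between query and key positions, so the score matrix is constant along each of its diagonals.

```python
import torch
from model import precompute_freqs_cis, apply_rotary_emb  # assumes this branch's model.py

T, hs = 16, 32
cos, sin = precompute_freqs_cis(hs, T)
q = torch.randn(hs).expand(1, 1, T, hs)     # same query content at every position
k = torch.randn(hs).expand(1, 1, T, hs)     # same key content at every position
q_rot, k_rot = apply_rotary_emb(q, k, cos, sin)

scores = q_rot[0, 0] @ k_rot[0, 0].T        # (T, T); scores[m, n] = <q_rot_m, k_rot_n>
for off in (-3, 0, 5):
    diag = torch.diagonal(scores, offset=off)
    assert torch.allclose(diag, diag[0].expand_as(diag), atol=1e-4)
```

If that invariance fails, the pairing of dimensions in rotate_hale and the (freqs, freqs) concatenation in precompute_freqs_cis are the first places to look.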