diff --git a/src/levanter/models/gemma.py b/src/levanter/models/gemma.py
index 1f8396b20..af5cc44be 100644
--- a/src/levanter/models/gemma.py
+++ b/src/levanter/models/gemma.py
@@ -28,6 +28,7 @@
     LlamaMlp,
 )
 from levanter.models.lm_model import LmConfig, LmHeadModel
+from levanter.models.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddingsConfig
 from levanter.types import BlockFoldable
 from levanter.utils.flop_utils import lm_flops_per_token
 
@@ -80,7 +81,6 @@ class GemmaConfig(HFCompatConfig):
     attn_dropout = 0.0
     norm_eps = 1e-6
 
-    rope_base: int = 10_000
     norm_embeddings: bool = True
 
     # Attention-related config
@@ -94,9 +94,12 @@ class GemmaConfig(HFCompatConfig):
     scan_layers: bool = True
 
     use_bias: bool = False
-    rope_scaling: Optional[dict] = None
     rope_theta: float = 10000.0
 
+    @property
+    def rope(self) -> RotaryEmbeddingsConfig:
+        return DefaultRotaryEmbeddingsConfig(theta=self.rope_theta)
+
     # Axis
     Pos = property(lambda self: Axis(name="position", size=self.seq_len))
     KeyPos = property(lambda self: self.Pos.alias("key_position"))
@@ -146,7 +149,7 @@ def from_hf_config(cls, hf_config: HfConfig):
             num_kv_heads=hf_config.num_key_value_heads,
             initializer_range=hf_config.initializer_range,
             layer_norm_epsilon=hf_config.rms_norm_eps,
-            rope_base=hf_config.rope_theta,
+            rope_theta=hf_config.rope_theta,
         )
 
     def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfGemmaConfig:
diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py
index 2a2d2664d..e777b7636 100644
--- a/src/levanter/models/llama.py
+++ b/src/levanter/models/llama.py
@@ -1,9 +1,8 @@
 import dataclasses
 from dataclasses import dataclass
-from typing import Callable, Dict, Optional, Tuple, Type, Union
+from typing import Callable, Dict, Optional, Type, Union
 
 import equinox as eqx
-import jax
 import jax.numpy as jnp
 import jax.random as jrandom
 from jaxtyping import PRNGKeyArray
@@ -28,6 +27,7 @@
 from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention
 from levanter.models.gpt2 import ACT2FN
 from levanter.models.lm_model import LmConfig, LmHeadModel
+from levanter.models.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddingsConfig
 from levanter.types import BlockFoldable
 from levanter.utils.flop_utils import lm_flops_per_token
 
@@ -77,8 +77,7 @@ class LlamaConfig(HFCompatConfig):
     use_bias: bool = False
     use_layer_norm_weight: bool = True
 
-    rope_scaling: Optional[dict] = None
-    rope_theta: float = 10000.0
+    rope: RotaryEmbeddingsConfig = dataclasses.field(default_factory=DefaultRotaryEmbeddingsConfig)
 
     reference_checkpoint: str = "meta-llama/Llama-2-7b-hf"
     tokenizer: Optional[str] = None
@@ -109,6 +108,8 @@ def hf_checkpoint_converter(self) -> HFCheckpointConverter["LlamaConfig"]:  # ty
 
     @classmethod
     def from_hf_config(cls, hf_config: HfConfig):
+        rope_theta = hf_config.rope_theta
+        rope_config = RotaryEmbeddingsConfig.from_hf_config(rope_theta, hf_config.rope_scaling)
         return LlamaConfig(
             seq_len=hf_config.max_position_embeddings,
             hidden_dim=hf_config.hidden_size,
@@ -119,8 +120,7 @@ def from_hf_config(cls, hf_config: HfConfig):
             activation_function=hf_config.hidden_act,
             initializer_range=hf_config.initializer_range,
             layer_norm_epsilon=hf_config.rms_norm_eps,
-            rope_scaling=hf_config.rope_scaling,
-            rope_theta=hf_config.rope_theta,
+            rope=rope_config,
         )
 
     def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfLlamaConfig:
@@ -136,6 +136,8 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
         if config_overrides is None:
             config_overrides = {}
 
+        rope_theta, rope_scaling = self.rope.to_hf_config()
+
         return HfLlamaConfig(
             max_position_embeddings=self.seq_len,
             hidden_size=self.hidden_dim,
@@ -146,9 +148,10 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
             hidden_act=self.activation_function,
             initializer_range=self.initializer_range,
             rms_norm_eps=self.layer_norm_epsilon,
-            rope_scaling=self.rope_scaling,
+            # rope_scaling=self.rope_scaling,
             vocab_size=vocab_size,
-            rope_theta=self.rope_theta,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             **config_overrides,
         )
 
@@ -274,13 +277,6 @@ def init(config: LlamaConfig, *, key) -> "LlamaAttention":
         )
         return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj)
 
-    def _rope_scale_factor(self) -> float:
-        # hasattr for gemma and I'm feeling lazy
-        if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
-            assert self.config.rope_scaling["type"] == "linear"
-            return self.config.rope_scaling["factor"]
-        return 1.0
-
     @named_call
     def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *, key=None) -> NamedArray:
         key_q, key_k, key_v, key_o = maybe_rng_split(key, 4)
@@ -290,13 +286,8 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,
         k = self.k_proj(x, key=key_k).rearrange((..., "kv_heads", "position", "head_size"))
         v = self.v_proj(x, key=key_v).rearrange((..., "kv_heads", "position", "head_size"))
 
-        cos, sin = llama_rotary_pos_emb(
-            self.config.HeadSize,
-            x.resolve_axis("position"),
-            scale=self._rope_scale_factor(),
-            theta=self.config.rope_theta,
-        )
-        q, k = _apply_rotary_pos_emb(q, k, cos, sin)
+        rot_embs = self.config.rope.build(self.config.HeadSize, q.resolve_axis("position"))
+        q, k = rot_embs(self.config.HeadSize, q, k)
 
         k = k.rename({"position": "key_position"})
         v = v.rename({"position": "key_position"})
@@ -588,43 +579,3 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None)
         state_dict.update(my_dict)
         return state_dict
-
-
-def _rotate_half(x: NamedArray) -> NamedArray:
-    """Rotates half of the hidden dims of the input and concatenates them."""
-    HeadSize = x.axes[-1]
-    x1 = x[HeadSize, : HeadSize.size // 2]
-    x2 = x[HeadSize, HeadSize.size // 2 :]
-    out = hax.concatenate(HeadSize, (-x2, x1))
-    return out
-
-
-def _apply_rotary_pos_emb(
-    q: NamedArray,  # [batch, position, kv_heads, q_heads_per_group, head_size]
-    k: NamedArray,  # [batch, position, kv_heads, head_size]
-    cos: NamedArray,  # [position, head_size]
-    sin: NamedArray,  # [position, head_size]
-) -> Tuple[NamedArray, NamedArray]:
-    """Applies rotary position embedding to q and k."""
-    q_embed = q * cos + _rotate_half(q) * sin
-    k_embed = k * cos + _rotate_half(k) * sin
-    return q_embed, k_embed
-
-
-def llama_rotary_pos_emb(
-    HeadSize: Axis, Pos: Axis, theta: float = 10000, scale: float = 1.0
-) -> Tuple[NamedArray, NamedArray]:
-    with jax.ensure_compile_time_eval():
-        HeadHalfSize = HeadSize.resize(HeadSize.size // 2)
-        inv_freq: NamedArray = 1.0 / (theta ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size))
-
-        position_ids: NamedArray = hax.arange(Pos) / scale
-
-        freqs = position_ids * inv_freq.broadcast_axis(Pos)
-        # This is different from the paper but aligns with HF implementation:
-        # It uses a different permutation in order to obtain the same calculation
-        emb = hax.concatenate(HeadSize, (freqs, freqs))
-        cos = hax.cos(emb)
-        sin = hax.sin(emb)
-        # This is different from the paper but aligns with HF implementation:
-        return cos, sin
diff --git a/src/levanter/models/rotary.py b/src/levanter/models/rotary.py
new file mode 100644
index 000000000..07657e5ff
--- /dev/null
+++ b/src/levanter/models/rotary.py
@@ -0,0 +1,182 @@
+import abc
+from dataclasses import dataclass
+from typing import Tuple
+
+import draccus
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+
+import haliax as hax
+from haliax import Axis, NamedArray
+
+
+def _rotate_half(x: NamedArray, HeadSize: Axis) -> NamedArray:
+    """Rotates half of the hidden dims of the input and concatenates them."""
+    x1 = x[HeadSize, : HeadSize.size // 2]
+    x2 = x[HeadSize, HeadSize.size // 2 :]
+    out = hax.concatenate(HeadSize, (-x2, x1))
+    return out
+
+
+class RotaryEmbeddings(eqx.Module):
+    cos: NamedArray
+    sin: NamedArray
+
+    @property
+    def nograd_cos(self):
+        return jax.lax.stop_gradient(self.cos)
+
+    @property
+    def nograd_sin(self):
+        return jax.lax.stop_gradient(self.sin)
+
+    def __call__(self, HeadDim: Axis, q: NamedArray, k: NamedArray) -> tuple[NamedArray, NamedArray]:
+        q_embed = q * self.nograd_cos + _rotate_half(q, HeadDim) * self.nograd_sin
+        k_embed = k * self.nograd_cos + _rotate_half(k, HeadDim) * self.nograd_sin
+        return q_embed, k_embed
+
+
+@dataclass
+class RotaryEmbeddingsConfig(abc.ABC, draccus.ChoiceRegistry):
+    @abc.abstractmethod
+    def build(self, HeadSize: Axis, Pos: Axis) -> RotaryEmbeddings:
+        pass
+
+    @staticmethod
+    def from_hf_config(rope_theta, config: dict | None) -> "RotaryEmbeddingsConfig":
+        if config is None:
+            return DefaultRotaryEmbeddingsConfig(theta=rope_theta)
+        tpe = config.get("rope_type") or config.get("type") or "default"
+        return RotaryEmbeddingsConfig.get_choice_class(tpe).make_from_hf_config(rope_theta, config)
+
+    @classmethod
+    @abc.abstractmethod
+    def make_from_hf_config(cls, rope_theta: float, config: dict) -> "RotaryEmbeddingsConfig":
+        pass
+
+    @abc.abstractmethod
+    def to_hf_config(self) -> tuple[float, dict | None]:
+        """Returns the rope_theta and config dict for the HF config."""
+        pass
+
+
+@dataclass
+class DefaultRotaryEmbeddingsConfig(RotaryEmbeddingsConfig):
+    theta: float = 10000
+    factor: float = 1.0  # this should have been called scale_factor, but for hf compat
+
+    def build(self, HeadSize: Axis, Pos: Axis) -> RotaryEmbeddings:
+        with jax.ensure_compile_time_eval():
+            HeadHalfSize = HeadSize.resize(HeadSize.size // 2)
+            inv_freq: NamedArray = 1.0 / (self.theta ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size))
+            inv_freq = inv_freq / self.factor
+
+            position_ids: NamedArray = hax.arange(Pos)
+
+            freqs = position_ids * inv_freq.broadcast_axis(Pos)
+            emb = hax.concatenate(HeadSize, (freqs, freqs))
+            cos = hax.cos(emb)
+            sin = hax.sin(emb)
+            return RotaryEmbeddings(cos=cos, sin=sin)
+
+    @classmethod
+    def make_from_hf_config(cls, rope_theta: float, config: dict) -> "RotaryEmbeddingsConfig":
+        return DefaultRotaryEmbeddingsConfig(theta=rope_theta, factor=config.get("factor", 1.0))
+
+    def to_hf_config(self) -> tuple[float, dict | None]:
+        if self.factor == 1.0:
+            return self.theta, None
+        return self.theta, {"factor": self.factor}
+
+
+RotaryEmbeddingsConfig.register_subclass("default", DefaultRotaryEmbeddingsConfig)
+RotaryEmbeddingsConfig.register_subclass("linear", DefaultRotaryEmbeddingsConfig)
+
+
+@dataclass
+class Llama3RotaryEmbeddingsConfig(RotaryEmbeddingsConfig):
+    """
+    To match this from HF:
+    "rope_scaling": {
+        "factor": 8.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_max_position_embeddings": 8192,
+        "rope_type": "llama3"
+    },
+    """
+
+    theta: float = 500000
+    factor: float = 8.0
+    low_freq_factor: float = 1.0
+    high_freq_factor: float = 4.0
+    original_max_position_embeddings: int = 8192
+
+    def build(self, HeadSize: Axis, Pos: Axis) -> RotaryEmbeddings:
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
+        # Porting that to JAX/Haliax:
+        with jax.ensure_compile_time_eval():
+            HeadHalfSize = HeadSize.resize(HeadSize.size // 2)
+            inv_freq: NamedArray = 1.0 / (self.theta ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size))
+
+            old_context_len = self.original_max_position_embeddings
+            low_freq_wavelen = old_context_len / self.low_freq_factor
+            high_freq_wavelen = old_context_len / self.high_freq_factor
+
+            wavelen = 2 * jnp.pi / inv_freq
+            inv_freq_llama = hax.where(wavelen > low_freq_wavelen, inv_freq / self.factor, inv_freq)
+            smooth_factor = (old_context_len / wavelen - self.low_freq_factor) / (
+                self.high_freq_factor - self.low_freq_factor
+            )
+            smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama
+            is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+            inv_freq_llama = hax.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+            position_ids: NamedArray = hax.arange(Pos)
+
+            freqs = position_ids * inv_freq_llama.broadcast_axis(Pos)
+            emb = hax.concatenate(HeadSize, (freqs, freqs))
+            cos = hax.cos(emb)
+            sin = hax.sin(emb)
+            return RotaryEmbeddings(cos=cos, sin=sin)
+
+    @classmethod
+    def make_from_hf_config(cls, rope_theta: float, config: dict) -> "RotaryEmbeddingsConfig":
+        return Llama3RotaryEmbeddingsConfig(
+            theta=rope_theta,
+            factor=config.get("factor", 8.0),
+            low_freq_factor=config.get("low_freq_factor", 1.0),
+            high_freq_factor=config.get("high_freq_factor", 4.0),
+            original_max_position_embeddings=config.get("original_max_position_embeddings", 8192),
+        )
+
+    def to_hf_config(self) -> tuple[float, dict]:
+        return self.theta, {
+            "factor": self.factor,
+            "low_freq_factor": self.low_freq_factor,
+            "high_freq_factor": self.high_freq_factor,
+            "original_max_position_embeddings": self.original_max_position_embeddings,
+        }
+
+
+RotaryEmbeddingsConfig.register_subclass("llama3", Llama3RotaryEmbeddingsConfig)
+
+
+def rotary_pos_emb(
+    HeadSize: Axis, Pos: Axis, theta: float = 10000, scale: float = 1.0
+) -> Tuple[NamedArray, NamedArray]:
+    with jax.ensure_compile_time_eval():
+        HeadHalfSize = HeadSize.resize(HeadSize.size // 2)
+        inv_freq: NamedArray = 1.0 / (theta ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) / scale
+
+        position_ids: NamedArray = hax.arange(Pos)
+
+        freqs = position_ids * inv_freq.broadcast_axis(Pos)
+        # This is different from the paper but aligns with HF implementation:
+        # It uses a different permutation in order to obtain the same calculation
+        emb = hax.concatenate(HeadSize, (freqs, freqs))
+        cos = hax.cos(emb)
+        sin = hax.sin(emb)
+        # This is different from the paper but aligns with HF implementation:
+        return cos, sin
diff --git a/tests/test_llama.py b/tests/test_llama.py
index 4277150fe..2d2b6506f 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -12,9 +12,8 @@
 from levanter.models.attention import AttentionMask
 from levanter.models.llama import LlamaAttention, LlamaConfig, LlamaDecoderLayer, LlamaLMHeadModel, LlamaRMSNorm
-from levanter.models.llama import _apply_rotary_pos_emb as levanter_apply_rotary_pos_emb
-from levanter.models.llama import _rotate_half as levanter_rotate_half
-from levanter.models.llama import llama_rotary_pos_emb
+from levanter.models.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddings
+from levanter.models.rotary import _rotate_half as levanter_rotate_half
 from levanter.utils.jax_utils import parameter_count
 from test_utils import check_load_config, check_model_works_with_seqlen, parameterize_with_configs, skip_if_no_torch
 
 
@@ -71,7 +70,9 @@ def test_llama_rotary_embedding():
     x = random.normal(key, (1, seq_len))
     x_torch = torch.from_numpy(np.array(x))
 
-    levanter_output = llama_rotary_pos_emb(HeadSize=HeadSize, Pos=Pos)
+    levanter_emb = DefaultRotaryEmbeddingsConfig().build(HeadSize=HeadSize, Pos=Pos)
+    levanter_output = (levanter_emb.cos, levanter_emb.sin)
+
     hf_rope = HFLlamaRotaryEmbedding(dim=hidden_dim, max_position_embeddings=seq_len, device=device)
     hf_output = hf_rope(x_torch, torch.arange(seq_len).reshape(1, -1))
 
@@ -106,8 +107,8 @@ def named_array_to_tensor(named_array):
     k = hax.random.normal(random.PRNGKey(1), (Batch, Pos, Heads, HeadSize))
 
     # Check the output of _rotate_half() from levanter and hf
-    levanter_out_rf_q = levanter_rotate_half(q)
-    levanter_out_rf_k = levanter_rotate_half(k)
+    levanter_out_rf_q = levanter_rotate_half(q, HeadSize)
+    levanter_out_rf_k = levanter_rotate_half(k, HeadSize)
 
     q_tensor = named_array_to_tensor(q).transpose(1, 2)  # needed for HF
     k_tensor = named_array_to_tensor(k).transpose(1, 2)
@@ -121,7 +122,9 @@ def named_array_to_tensor(named_array):
     cos = hax.random.normal(random.PRNGKey(2), (Pos, HeadSize))
     sin = hax.random.normal(random.PRNGKey(3), (Pos, HeadSize))
 
-    levanter_out_rope_q, levanter_out_rope_k = levanter_apply_rotary_pos_emb(q, k, cos, sin)
+    rot = RotaryEmbeddings(cos=cos, sin=sin)
+
+    levanter_out_rope_q, levanter_out_rope_k = rot(HeadSize, q, k)
     cos_tensor = named_array_to_tensor(cos)[None, :, :]
     sin_tensor = named_array_to_tensor(sin)[None, :, :]
 
@@ -328,7 +331,6 @@ def _get_llama_config(use_flash=False, num_kv_heads=4, seq_len=128) -> LlamaConf
         hidden_dim=16,
         num_heads=4,
         num_kv_heads=num_kv_heads,
-        rope_scaling=None,
         gradient_checkpointing=False,  # disable for tests so debugging is easier
         use_flash_attention=use_flash,
         flash_attention_block_size=8 if use_flash else None,
diff --git a/tests/test_llama3.py b/tests/test_llama3.py
index a6f1d67b8..2fae326d1 100644
--- a/tests/test_llama3.py
+++ b/tests/test_llama3.py
@@ -35,7 +35,13 @@ def get_config(vocab_size=1000):
         "num_key_value_heads": 8,
         "pretraining_tp": 1,
         "rms_norm_eps": 0.00001,
-        "rope_scaling": null,
+        "rope_scaling": {
+            "factor": 8.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_max_position_embeddings": 8192,
+            "rope_type": "llama3"
+        },
         "rope_theta": 500000,
         "tie_word_embeddings": false,
         "torch_dtype": "bfloat16",
@@ -110,3 +116,31 @@ def compute(model, input):
     torch_out2 = torch_out2.logits[0].detach().cpu().numpy()
     assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}"
     np.testing.assert_allclose(torch_out2, jax_out, rtol=1e-5, atol=1e-5)
+
+
+@skip_if_no_torch
+def test_llama3_rotary_embedding():
+    import torch
+    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFLlamaRotaryEmbedding
+
+    llama_config = get_config()
+    key = random.PRNGKey(0)
+    device = "cpu"
+
+    lev_config = LlamaConfig.from_hf_config(llama_config)
+    HeadSize = lev_config.HeadSize
+    Pos = lev_config.Pos
+    seq_len = Pos.size
+
+    x = random.normal(key, (1, seq_len))
+    x_torch = torch.from_numpy(np.array(x))
+
+    levanter_emb = lev_config.rope.build(HeadSize, Pos)
+    levanter_output = (levanter_emb.cos, levanter_emb.sin)
+
+    hf_rope = HFLlamaRotaryEmbedding(max_position_embeddings=seq_len, device=device, config=llama_config)
+    hf_output = hf_rope(x_torch, torch.arange(seq_len).reshape(1, -1))
+
+    for jax_out, torch_out in zip(levanter_output, hf_output):
+        torch_out = torch_out.numpy()
+        assert np.isclose(torch_out, np.array(jax_out.array), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}"
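Usage sketch (illustrative only, not part of the patch): the class, field, and method names below are taken from the diff above, the literal values mirror the Llama3 defaults shown there, and the YAML/draccus wiring is intentionally omitted.

from levanter.models.llama import LlamaConfig
from levanter.models.rotary import Llama3RotaryEmbeddingsConfig, RotaryEmbeddingsConfig

# Pick a rope variant directly on the model config
# (replaces the old rope_scaling/rope_theta fields on LlamaConfig).
config = LlamaConfig(
    seq_len=8192,
    rope=Llama3RotaryEmbeddingsConfig(theta=500000, factor=8.0),
)

# Or derive the variant from an HF-style rope_scaling dict,
# which is what LlamaConfig.from_hf_config now does internally.
rope = RotaryEmbeddingsConfig.from_hf_config(
    500000,
    {
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
)

# Build the cos/sin tables once for a (HeadSize, Pos) pair; LlamaAttention
# then applies them to q/k via rot_embs(HeadSize, q, k).
rot_embs = rope.build(config.HeadSize, config.Pos)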