Skip to content

Commit

Permalink
wip: fixing RoPE not preserving dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
knyazer committed Sep 7, 2024
1 parent dae889d commit 0ed1ee6
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions equinox/nn/_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,18 +177,16 @@ def rotate_half(x: Float[Array, "seq_length embedding_size"]):
@staticmethod
def precompute_freqs_cis(
embedding_size: int, end: int, theta: float
) -> Complex[Array, "end half_of_embedding_size"]:
) -> tuple[Float[Array, "end half_of_embedding_size"], Float[Array, "end half_of_embedding_size"]]:
freqs = 1.0 / (
theta
** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size)
)

t = jnp.arange(float(end))
freqs_outer = jnp.outer(t, freqs)
with jax.numpy_dtype_promotion("standard"):
freqs_cis = jnp.cos(freqs_outer) + jnp.sin(freqs_outer) * 1j

return freqs_cis
return jnp.cos(freqs_outer), jnp.sin(freqs_outer)

@jax.named_scope("eqx.nn.RotaryPositionalEmbedding")
def __call__(
Expand Down Expand Up @@ -218,23 +216,24 @@ def __call__(

with jax.ensure_compile_time_eval():
if embedding_size in internal_rope_embedding_cache:
freqs_cis = internal_rope_embedding_cache[embedding_size]
freqs_cis_seq_len, _ = freqs_cis.shape
freqs_cis_real, freqs_cis_imag = internal_rope_embedding_cache[embedding_size]
freqs_cis_seq_len, _ = freqs_cis_real.shape
if seq_len > freqs_cis_seq_len:
freqs_cis = self.precompute_freqs_cis(
freqs_cis_real, freqs_cis_imag = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta
)
internal_rope_embedding_cache[embedding_size] = freqs_cis
internal_rope_embedding_cache[embedding_size] = (freqs_cis_real, freqs_cis_imag)
else:
freqs_cis = freqs_cis[:seq_len]
freqs_cis_real = freqs_cis_real[:seq_len]
freqs_cis_imag = freqs_cis_imag[:seq_len]
else:
freqs_cis = self.precompute_freqs_cis(
freqs_cis_real, freqs_cis_imag = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta
)
internal_rope_embedding_cache[embedding_size] = freqs_cis
internal_rope_embedding_cache[embedding_size] = (freqs_cis_real, freqs_cis_imag)

freqs_real = jnp.tile(freqs_cis.real, (1, 2))
freqs_imag = jnp.tile(freqs_cis.imag, (1, 2))
freqs_real = jnp.tile(freqs_cis_real, (1, 2))
freqs_imag = jnp.tile(freqs_cis_imag, (1, 2))

rotate_x = self.rotate_half(x)
x_rope = (x * freqs_real) + (rotate_x * freqs_imag)
Expand Down

0 comments on commit 0ed1ee6

Please sign in to comment.