Add Rotary Positional Embeddings (RoPE) - part 2 of parallel attention blocks #450
base: main
Changes from all commits
1f273a0
b01516a
ceda766
aafa9b3
a2a98ba
15c7469
7c21fba
06fe2a9
2d0ab30
@@ -4,13 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import pytest
+import math
+
+import pytest
 import torch
 from tests.test_utils import assert_expected
 from torch import nn

 from torchmultimodal.modules.layers.position_embedding import (
     BroadcastedPositionEmbedding,
+    RotaryPositionalEmbeddings,
     SinusoidalPositionEmbeddings,
 )

@@ -112,3 +115,38 @@ def test_forward(self, data, emb):
         actual = emb(data)
         expected = torch.Size([3, 5])
         assert_expected(actual.shape, expected)
+
+
+def test_rotary_embeddings_math():
+    q = (
+        torch.tensor([[1, 0], [1, 0]], dtype=torch.float).unsqueeze(0).unsqueeze(0)
+    )  # b h s e
+
+    k = 2 * torch.tensor([[1, 0], [1, 0]], dtype=torch.float).unsqueeze(0).unsqueeze(
+        0
+    )  # b h s e
+
+    rotary_embeddings = RotaryPositionalEmbeddings(2, 2, 1)
+    qr, kr = rotary_embeddings(q, k, 0)
+    rot0 = torch.tensor([[math.cos(0), -math.sin(0)], [math.sin(0), math.cos(0)]])
+    rot1 = torch.tensor([[math.cos(1), -math.sin(1)], [math.sin(1), math.cos(1)]])
+
+    assert_expected(torch.matmul(rot0, q[..., 0, :].squeeze()), qr[..., 0, :].squeeze())
+    assert_expected(torch.matmul(rot1, q[..., 1, :].squeeze()), qr[..., 1, :].squeeze())
+    assert_expected(torch.matmul(rot0, k[..., 0, :].squeeze()), kr[..., 0, :].squeeze())
+    assert_expected(torch.matmul(rot1, k[..., 1, :].squeeze()), kr[..., 1, :].squeeze())
+
+
+def test_rotary_embeddings_left_padding():
+    q = torch.ones(2, 1, 4, 16, dtype=torch.float)  # b h s e
+    k = 2 * torch.ones(2, 1, 4, 16, dtype=torch.float)  # b h s e
+    rotary_embeddings = RotaryPositionalEmbeddings(16, 32)
+
+    qr, kr = rotary_embeddings(q, k, 0)
+    qr2, kr2 = rotary_embeddings(q, k, torch.tensor([0, 1]))
+
+    assert_expected(qr[0], qr2[0])
+    assert_expected(qr[0, :, 1], qr2[1, :, 0])
+
+    assert_expected(kr[0], kr2[0])
+    assert_expected(kr[0, :, 1], kr2[1, :, 0])
Review comment: Can we also add a test for updating the cached frequencies? (As far as I can tell this second test is not hitting that block in L262-268, lmk if I'm misunderstanding)
Reply: yes, that's a good idea.
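A minimal sketch of what such a test could look like (the test name is hypothetical and this is not part of the diff; it relies on the imports already in this file and on the PR's compute_freqs_cis, max_seq_len_cached, and cached_freqs attributes behaving as written below):

def test_rotary_embeddings_recompute_cache():
    q = torch.ones(1, 1, 4, 16, dtype=torch.float)  # b h s e
    k = 2 * torch.ones(1, 1, 4, 16, dtype=torch.float)  # b h s e
    # the cache only covers 2 positions, so start_pos=8 with seq_len=4
    # should force the frequencies to be recomputed
    rotary_embeddings = RotaryPositionalEmbeddings(16, 2)
    assert rotary_embeddings.max_seq_len_cached == 2

    rotary_embeddings(q, k, 8)

    # the cache should now cover positions 0..11 (start_pos + seq_len)
    assert rotary_embeddings.max_seq_len_cached == 12
    assert rotary_embeddings.cached_freqs.shape[0] == 12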
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import itertools
-from typing import Tuple
+from typing import Tuple, Union

 import torch
 from torch import nn, Tensor
@@ -169,3 +169,108 @@ def forward(self, t: Tensor) -> Tensor:
         if self.embed_dim % 2 == 1:
             embeddings = nn.functional.pad(embeddings, (0, 1))
         return embeddings
+
+
+class RotaryPositionalEmbeddings(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        max_position_embeddings: Union[int, float] = 2048,
+        ratio: int = 10000,
+        device: torch.device = None,
+    ):
+        """
+        Implements Rotary Positional Embeddings (RoPE)
+        proposed in: https://arxiv.org/abs/2104.09864
+
+        Args
+        ----
+        dim : int
+            Per-head embedding dimension
+        max_position_embeddings : int
+            Maximum expected sequence length for the model; if exceeded, the cached freqs will be recomputed
+        ratio : int
+            The ratio for the geometric progression to compute the rotation angles
+        """
Review comment: It'd be nice to add more in the docstring on the exact details of these embeddings, e.g. at least the [[cos, -sin], [sin, cos]] matrix and maybe even a small example (like the simple 2D one you wrote for the unit test)
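For reference, the relation such a docstring could spell out (an editor's sketch derived from the code below, not text from the PR): with theta_i = ratio ** (-2i / dim) for channel pair i, RoPE rotates the pair (x_{2i}, x_{2i+1}) at position m by the angle m * theta_i, i.e. multiplies it by

    [[cos(m * theta_i), -sin(m * theta_i)],
     [sin(m * theta_i),  cos(m * theta_i)]]

which is exactly the 2x2 block this module caches per position and frequency.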
+        super().__init__()
+        self.register_buffer(
+            "freqs",
+            1.0
+            / (
+                ratio
+                ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
+            ),
+        )
+        self.compute_freqs_cis(max_position_embeddings)
+
+    def compute_freqs_cis(
Review comment: Random q: what does cis mean here?
Reply: it's short form for rotation transform, technically doing e^(alpha*i) = cos(alpha) + i * sin(alpha), or shortened, cos + i * sin = cis.
Reply: should probably add that in the docstring actually, otherwise it's too cryptic.
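In other words (editor's note, not from the PR): cis(alpha) = e^(i*alpha) = cos(alpha) + i*sin(alpha), and treating each channel pair as the complex number x_{2i} + i*x_{2i+1}, multiplying by cis(m * theta_i) performs the same rotation that compute_freqs_cis stores as the stacked [cos, -sin, sin, cos] blocks in cached_freqs.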
+        self, max_position_embeddings: Union[int, float] = 2048
+    ) -> None:
+        t = torch.arange(
+            max_position_embeddings, device=self.freqs.device, dtype=self.freqs.dtype
+        )
+        freqs = torch.outer(t, self.freqs).float()
+        self.max_seq_len_cached = max_position_embeddings
+        self.register_buffer(
+            "cached_freqs",
+            torch.stack(
+                [
+                    torch.cos(freqs),
+                    -torch.sin(freqs),
+                    torch.sin(freqs),
+                    torch.cos(freqs),
+                ],
+                dim=2,
+            ).view(*freqs.shape, 2, 2),
+        )
+
+    def reshape_for_broadcast(
+        self, x: torch.Tensor, cur_freqs: torch.Tensor
+    ) -> torch.Tensor:
+        ndim = x.ndim
+        assert 1 < ndim
+        assert cur_freqs.shape[:2] == (x.shape[2], x.shape[-2])
+        shape = [d if i == 2 or i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+        return cur_freqs.view(*shape, 2)
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        start_pos: Union[int, float, torch.LongTensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args
+        ----
+        q : torch.Tensor
+            Embedded query tensor, expected size is B x H x S x Eh
+        k : torch.Tensor
+            Embedded key tensor, expected size is B x H x S x Eh
+        start_pos : Union[int, torch.LongTensor]
+            The starting position of the tokens encoded in q and k. This is important in
+            kv-caching and left-padding situations, for which the rotation to be applied might
+            not always be the pre-cached position 0...S. For kv-caching without dynamic batching,
+            start_pos is shared across the batch.
+        """
+        seq_len = q.shape[2]
+        q_ = q.float().reshape(*q.shape[:-1], -1, 2)  # B H L D/2 2
+        k_ = k.float().reshape(*k.shape[:-1], -1, 2)  # B H L D/2 2
+
+        if isinstance(start_pos, int):
+            if start_pos + seq_len > self.max_seq_len_cached:
Review comment: Some comments here about when the frequencies need to be recomputed might be helpful
Reply: sounds good - offhand should be changing dtype, changing device, and resetting seq len > max_seq_len.
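A sketch of the sort of comment the thread is asking for, to sit above this check (hypothetical wording, not in the diff):

    # cached_freqs only covers positions 0..max_seq_len_cached - 1, so rebuild the
    # cache when the requested window (start_pos + seq_len) extends past it; per the
    # discussion above, a dtype or device change would also invalidate the cache.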
+                self.compute_freqs_cis(start_pos + seq_len)
+            cur_freqs = self.cached_freqs[start_pos : start_pos + seq_len]
+            freqs = self.reshape_for_broadcast(q_, cur_freqs)
+        else:
+            max_start_pos = torch.max(start_pos).item()
+            if max_start_pos + seq_len > self.max_seq_len_cached:
+                self.compute_freqs_cis(max_start_pos + seq_len)
+            freqs_idxs = torch.arange(0, seq_len, dtype=torch.long).repeat(
+                start_pos.shape[0]
+            ).view(-1, seq_len) + start_pos.view(-1, 1)
+            freqs = self.cached_freqs[freqs_idxs].unsqueeze(1)
+
+        freqs = freqs.float()  # 1 1 L D/2 2 2
+        q_out = freqs.mul(q_.unsqueeze(-2)).sum(5).flatten(3)
+        k_out = freqs.mul(k_.unsqueeze(-2)).sum(5).flatten(3)
+        return q_out.type_as(q).contiguous(), k_out.type_as(k).contiguous()
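One way to read the last three lines (editor's note): freqs holds one [[cos, -sin], [sin, cos]] block per position and channel pair, so the element-wise multiply, the sum over the last axis, and the flatten implement a batched 2x2 matrix-vector product, i.e. for absolute position p and channel pair i,

    (q_out[b, h, p, 2i], q_out[b, h, p, 2i+1]) = R(p * theta_i) @ (q[b, h, p, 2i], q[b, h, p, 2i+1])

and the same holds for k_out.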
Review comment: Can we put these unit tests into a class? (Similar to the other tests in this file)
Reply: yes, will do.
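A possible shape for that refactor (a sketch only; the class and fixture names are made up, following the fixture style of the existing tests in this file and reusing the file's imports):

class TestRotaryPositionalEmbeddings:
    @pytest.fixture
    def q(self):
        return torch.ones(2, 1, 4, 16, dtype=torch.float)  # b h s e

    @pytest.fixture
    def k(self, q):
        return 2 * q

    @pytest.fixture
    def emb(self):
        return RotaryPositionalEmbeddings(16, 32)

    def test_left_padding(self, q, k, emb):
        qr, kr = emb(q, k, 0)
        qr2, kr2 = emb(q, k, torch.tensor([0, 1]))
        # batch element 1 is shifted by one position relative to batch element 0
        assert_expected(qr[0], qr2[0])
        assert_expected(qr[0, :, 1], qr2[1, :, 0])
        assert_expected(kr[0], kr2[0])
        assert_expected(kr[0, :, 1], kr2[1, :, 0])

The other two tests from the diff (the 2D rotation check and, if added, the cache-recompute check) would move into the same class as additional methods.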