Skip to content

Commit

Permalink
[Rotary] Fix performance drop for varlen (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed Jan 14, 2025
1 parent 43fe3b0 commit 656657a
Showing 1 changed file with 119 additions and 151 deletions.
270 changes: 119 additions & 151 deletions fla/modules/rotary.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,140 +20,121 @@ def rotate_half(x, interleaved=False):
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)


def rotary_embedding_ref(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
x: (N, T, H, D)
cos, sin: (TR, DR / 2) or (N, TR, DR / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
cos = repeat(cos, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 1 (d 2)')
sin = repeat(sin, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 1 (d 2)')
return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], -1)


@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps)
for num_warps in [2, 4, 8, 16, 32]
triton.Config({'BM': BM}, num_warps=num_warps)
for BM in [4, 8, 16, 32, 64, 128]
for num_warps in [2, 4, 8, 16]
],
key=["BLOCK_K", "BLOCK_M", "INTERLEAVED"],
key=['B', 'T', 'H', 'INTERLEAVED'],
)
@triton.jit
def rotary_embedding_kernel(
X,
COS,
SIN,
OUT,
CU_SEQLENS,
SEQLEN_OFFSETS, # this could be int or a pointer
x,
cos,
sin,
out,
cu_seqlens,
seq_offsets, # this could be int or a pointer
# Matrix dimensions
seqlen,
rotary_dim,
seqlen_ro,
B,
T,
H,
D,
TR,
# strides
stride_out_batch,
stride_out_seqlen,
stride_out_nheads,
stride_out_headdim,
stride_x_batch,
stride_x_seqlen,
stride_x_nheads,
stride_x_headdim,
# Meta-parameters
BLOCK_K: tl.constexpr,
BLOCK_M: tl.constexpr,
BK: tl.constexpr,
BM: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
IS_VARLEN: tl.constexpr,
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr
):
pid_m = tl.program_id(axis=0)
pid_batch = tl.program_id(axis=1)
pid_head = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2
i_m, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2)
DR = D // 2

if not IS_VARLEN:
X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
x = x + i_b * T*H*D + i_h * D
out = out + i_b * T*H*D + i_h * D
else:
start_idx = tl.load(CU_SEQLENS + pid_batch)
seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1)
T = eos - bos
x = x + bos * H*D + i_h * D
out = out + bos * H*D + i_h * D

if pid_m * BLOCK_M >= seqlen:
if i_m * BM >= T:
return
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rm = i_m * BM + tl.arange(0, BM)
if not IS_SEQLEN_OFFSETS_TENSOR:
rm_cs = rm + SEQLEN_OFFSETS
rm_cs = rm + seq_offsets
else:
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
rk = tl.arange(0, BLOCK_K)
rk_half = tl.arange(0, BLOCK_K // 2)
rm_cs = rm + tl.load(seq_offsets + i_b)
ok, ok2 = tl.arange(0, BK), tl.arange(0, BK // 2)

if not INTERLEAVED:
# Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
cos = tl.load(COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to(tl.float32)
sin = tl.load(SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)
x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)
x1 = tl.load(
X + rotary_dim_half * stride_x_headdim,
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
# Load the 1st and 2nd halves of x, do calculation, then store to 1st and 2nd halves of out
p_x = x + (rm[:, None] * H*D + ok2[None, :])
p_cos = cos + (rm_cs[:, None] * DR + ok2[None, :])
p_sin = sin + (rm_cs[:, None] * DR + ok2[None, :])
mask = (rm[:, None] < T) & (ok2[None, :] < DR)

b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32)
b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32)
b_x0 = tl.load(p_x, mask=mask, other=0.0).to(tl.float32)
b_x1 = tl.load(p_x + DR, mask=mask, other=0.0).to(tl.float32)
if CONJUGATE:
sin = -sin
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
b_sin = -b_sin
b_o0 = b_x0 * b_cos - b_x1 * b_sin
b_o1 = b_x0 * b_sin + b_x1 * b_cos
# write back result
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
tl.store(
OUT + rotary_dim_half * stride_out_headdim,
o1,
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
)
p_out = out + (rm[:, None] * H*D + ok2[None, :])
tl.store(p_out, b_o0, mask=mask)
tl.store(p_out + DR, b_o1, mask=mask)
else:
# We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
# Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
# We don't want to load x[0, 2, 4, ...] and x[1, 3, 5, ...] separately since both are slow.
# Instead, we load x0 = x[0, 1, 2, 3, ...] and x1 = x[1, 0, 3, 2, ...].
# Loading x0 will be fast but x1 will be slow.
# Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
# Then we load cos = cos[0, 0, 1, 1, ...] and sin = sin[0, 0, 1, 1, ...].
# Then we do the calculation and use tl.where to pick put the right outputs for the even
# and for the odd indices.
rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
rk_repeat = tl.arange(0, BLOCK_K) // 2
X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
cos = tl.load(
COS,
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
other=1.0,
).to(tl.float32)
sin = tl.load(
SIN,
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32)
x1 = tl.load(X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32)
rk_swap = ok + ((ok + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
rk_repeat = tl.arange(0, BK) // 2
p_x0 = x + (rm[:, None] * H*D + ok[None, :])
p_x1 = x + (rm[:, None] * H*D + rk_swap[None, :])
p_cos = cos + (rm_cs[:, None] * DR + rk_repeat[None, :])
p_sin = sin + (rm_cs[:, None] * DR + rk_repeat[None, :])
mask = (rm_cs[:, None] < TR) & (rk_repeat[None, :] < DR)

b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32)
b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32)
b_x0 = tl.load(p_x0, mask=mask, other=0.0).to(tl.float32)
b_x1 = tl.load(p_x1, mask=mask, other=0.0).to(tl.float32)
if CONJUGATE:
sin = -sin
x0_cos = x0 * cos
x1_sin = x1 * sin
out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
b_sin = -b_sin
b_x0_cos = b_x0 * b_cos
b_x1_sin = b_x1 * b_sin
b_out = tl.where(ok[None, :] % 2 == 0, b_x0_cos - b_x1_sin, b_x0_cos + b_x1_sin)
p_out = out + (rm[:, None] * H*D + ok[None, :])
tl.store(p_out, b_out, mask=mask)


@contiguous
def rotary_embedding_fwdbwd(
x: torch.Tensor,
cos: torch.Tensor,
Expand All @@ -167,53 +148,51 @@ def rotary_embedding_fwdbwd(
) -> torch.Tensor:
"""
Args:
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim).
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
cu_seqlens: (batch + 1,) or None
x: (N, T, H, D).
cos: (TR, DR / 2)
sin: (TR, DR / 2)
seqlen_offsets: integer or integer tensor of size (N,)
cu_seqlens: (N + 1,) or None
max_seqlen: int
Returns:
y: (batch, seqlen, nheads, headdim)
y: (N, T, H, D)
"""
is_varlen = cu_seqlens is not None

B, T, H, D = x.shape
if not is_varlen:
batch, seqlen, nheads, headdim = x.shape
N = B
else:
assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
*_, nheads, headdim = x.shape
batch_p_1 = cu_seqlens.shape[0]
batch = batch_p_1 - 1
seqlen = max_seqlen
seqlen_ro, rotary_dim = cos.shape
N, T = cu_seqlens.shape[0] - 1, max_seqlen
TR, DR = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
DR *= 2

assert D <= 256, "Only support D <= 256"
assert TR >= T, "TR must be >= T"
assert DR <= D, "DR must be <= D"

assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

cos, sin = cos.contiguous(), sin.contiguous()
if isinstance(seqlen_offsets, torch.Tensor):
assert seqlen_offsets.shape == (batch,)
assert seqlen_offsets.shape == (N,)
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
else:
assert seqlen_offsets + seqlen <= seqlen_ro
assert seqlen_offsets + T <= TR

output = torch.empty_like(x) if not inplace else x
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
if DR < D and not inplace:
output[..., DR:].copy_(x[..., DR:])

BLOCK_K = (
BK = (
32
if rotary_dim <= 32
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
if DR <= 32
else (64 if DR <= 64 else (128 if DR <= 128 else 256))
)
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)

def grid(META): return (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
def grid(META): return (triton.cdiv(T, META['BM']), N, H) # noqa
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(x.device.index):
Expand All @@ -224,25 +203,16 @@ def grid(META): return (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) #
output,
cu_seqlens,
seqlen_offsets,
seqlen, # shapes
rotary_dim,
seqlen_ro,
# batch_strides if not varlen else 0
output.stride(0) if not is_varlen else 0,
output.stride(-3), # seqlen_stride or total_seqlen_stride
output.stride(-2), # nheads_stride
output.stride(-1), # headdim_stride
# batch_strides if not varlen else 0
x.stride(0) if not is_varlen else 0,
x.stride(-3), # seqlen stride or total_seqlen_stride
x.stride(-2), # nheads stride
x.stride(-1), # headdim stride
BLOCK_K,
BLOCK_M,
isinstance(seqlen_offsets, torch.Tensor),
is_varlen,
interleaved,
conjugate
B,
T,
H,
DR,
TR,
BK,
IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor),
IS_VARLEN=is_varlen,
INTERLEAVED=interleaved,
CONJUGATE=conjugate
)
return output

Expand Down Expand Up @@ -322,22 +292,20 @@ def rotary_embedding(
):
"""
Args:
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
cos, sin: (seqlen_rotary, rotary_dim / 2)
x: (N, T, H, D)
cos, sin: (TR, DR / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
inplace: if True, apply rotary embedding in-place.
seqlen_offsets: (batch_size,) or int.
seqlen_offsets: (N,) or int.
Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
cu_seqlens: (batch + 1,) or None
cu_seqlens: (N + 1,) or None
max_seqlen: int
Return:
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
out: (N, T, H, D)
DR must be <= D
Apply rotary embedding to the first DR of x.
"""
return RotaryEmbeddingFunction.apply(
x,
Expand Down Expand Up @@ -495,14 +463,14 @@ def forward(
max_seqlen: Optional[int] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
q: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim)
q: (N, T, H, D)
k: (N, T, H, D)
seqlen_offset:
(batch_size,) or int. Each sequence in x is shifted by this amount.
(N,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
If it's a tensor of shape (N,), then to update the cos / sin cache, one
should pass in max_seqlen, which will update the cos / sin cache up to that length.
cu_seqlens: (batch + 1,) or None
cu_seqlens: (N + 1,) or None
max_seqlen: int
"""
if max_seqlen is not None:
Expand Down

0 comments on commit 656657a

Please sign in to comment.