From 656657acb7d58311e7cf42ba1bca850a99741526 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Wed, 15 Jan 2025 02:38:30 +0800 Subject: [PATCH] [Rotary] Fix performance drop for varlen (#117) --- fla/modules/rotary.py | 270 +++++++++++++++++++----------------------- 1 file changed, 119 insertions(+), 151 deletions(-) diff --git a/fla/modules/rotary.py b/fla/modules/rotary.py index 579ddcc17..7e0eb1fca 100644 --- a/fla/modules/rotary.py +++ b/fla/modules/rotary.py @@ -20,140 +20,121 @@ def rotate_half(x, interleaved=False): return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2) def rotary_embedding_ref(x, cos, sin, interleaved=False): """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + x: (N, T, H, D) + cos, sin: (TR, DR / 2) or (N, TR, DR / 2) """ ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] - cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") - sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + cos = repeat(cos, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 1 (d 2)') + sin = repeat(sin, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 1 (d 2)') return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], -1) @triton.autotune( configs=[ - triton.Config({}, num_warps=num_warps) - for num_warps in [2, 4, 8, 16, 32] + triton.Config({'BM': BM}, num_warps=num_warps) + for BM in [4, 8, 16, 32, 64, 128] + for num_warps in [2, 4, 8, 16] ], - key=["BLOCK_K", "BLOCK_M", "INTERLEAVED"], + key=['B', 'T', 'H', 'INTERLEAVED'], ) @triton.jit def rotary_embedding_kernel( - X, - COS, - SIN, - OUT, - CU_SEQLENS, - SEQLEN_OFFSETS, # this could be int or a pointer + x, + cos, + sin, + out, + cu_seqlens, + seq_offsets, # this could be int or a pointer # Matrix dimensions - seqlen, - rotary_dim, - seqlen_ro, + B, + T, + H, + D, + TR, # strides - stride_out_batch, - stride_out_seqlen, - stride_out_nheads, - stride_out_headdim, - stride_x_batch, - stride_x_seqlen, - stride_x_nheads, - stride_x_headdim, # Meta-parameters - BLOCK_K: tl.constexpr, - BLOCK_M: tl.constexpr, + BK: tl.constexpr, + BM: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr ): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - rotary_dim_half = rotary_dim // 2 + i_m, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2) + DR = D // 2 if not IS_VARLEN: - X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads - OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + x = x + i_b * T*H*D + i_h * D + out = out + i_b * T*H*D + i_h * D else: - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads - OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1) + T = eos - bos + x = x + bos * H*D + i_h * D + out = out + bos * H*D + i_h * D - if pid_m * BLOCK_M >= seqlen: + if i_m * BM >= T: return - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rm = i_m * BM + tl.arange(0, BM) if not IS_SEQLEN_OFFSETS_TENSOR: - rm_cs = rm + SEQLEN_OFFSETS + rm_cs = rm + seq_offsets else: - rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) - rk = tl.arange(0, BLOCK_K) - rk_half = tl.arange(0, BLOCK_K // 2) + rm_cs = rm + tl.load(seq_offsets + i_b) + ok, ok2 = tl.arange(0, BK), tl.arange(0, BK // 2) if not INTERLEAVED: - # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT - X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) - COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) - SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) - cos = tl.load(COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to(tl.float32) - sin = tl.load(SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32) - x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32) - x1 = tl.load( - X + rotary_dim_half * stride_x_headdim, - mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), - other=0.0, - ).to(tl.float32) + # Load the 1st and 2nd halves of x, do calculation, then store to 1st and 2nd halves of out + p_x = x + (rm[:, None] * H*D + ok2[None, :]) + p_cos = cos + (rm_cs[:, None] * DR + ok2[None, :]) + p_sin = sin + (rm_cs[:, None] * DR + ok2[None, :]) + mask = (rm[:, None] < T) & (ok2[None, :] < DR) + + b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32) + b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32) + b_x0 = tl.load(p_x, mask=mask, other=0.0).to(tl.float32) + b_x1 = tl.load(p_x + DR, mask=mask, other=0.0).to(tl.float32) if CONJUGATE: - sin = -sin - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos + b_sin = -b_sin + b_o0 = b_x0 * b_cos - b_x1 * b_sin + b_o1 = b_x0 * b_sin + b_x1 * b_cos # write back result - OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) - tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) - tl.store( - OUT + rotary_dim_half * stride_out_headdim, - o1, - mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), - ) + p_out = out + (rm[:, None] * H*D + ok2[None, :]) + tl.store(p_out, b_o0, mask=mask) + tl.store(p_out + DR, b_o1, mask=mask) else: - # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow. - # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. + # We don't want to load x[0, 2, 4, ...] and x[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = x[0, 1, 2, 3, ...] and x1 = x[1, 0, 3, 2, ...]. # Loading x0 will be fast but x1 will be slow. - # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...]. + # Then we load cos = cos[0, 0, 1, 1, ...] and sin = sin[0, 0, 1, 1, ...]. # Then we do the calculation and use tl.where to pick put the right outputs for the even # and for the odd indices. - rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... - rk_repeat = tl.arange(0, BLOCK_K) // 2 - X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) - X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) - COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) - SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) - cos = tl.load( - COS, - mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), - other=1.0, - ).to(tl.float32) - sin = tl.load( - SIN, - mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), - other=0.0, - ).to(tl.float32) - x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32) - x1 = tl.load(X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32) + rk_swap = ok + ((ok + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... + rk_repeat = tl.arange(0, BK) // 2 + p_x0 = x + (rm[:, None] * H*D + ok[None, :]) + p_x1 = x + (rm[:, None] * H*D + rk_swap[None, :]) + p_cos = cos + (rm_cs[:, None] * DR + rk_repeat[None, :]) + p_sin = sin + (rm_cs[:, None] * DR + rk_repeat[None, :]) + mask = (rm_cs[:, None] < TR) & (rk_repeat[None, :] < DR) + + b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32) + b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32) + b_x0 = tl.load(p_x0, mask=mask, other=0.0).to(tl.float32) + b_x1 = tl.load(p_x1, mask=mask, other=0.0).to(tl.float32) if CONJUGATE: - sin = -sin - x0_cos = x0 * cos - x1_sin = x1 * sin - out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) - OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) - tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + b_sin = -b_sin + b_x0_cos = b_x0 * b_cos + b_x1_sin = b_x1 * b_sin + b_out = tl.where(ok[None, :] % 2 == 0, b_x0_cos - b_x1_sin, b_x0_cos + b_x1_sin) + p_out = out + (rm[:, None] * H*D + ok[None, :]) + tl.store(p_out, b_out, mask=mask) +@contiguous def rotary_embedding_fwdbwd( x: torch.Tensor, cos: torch.Tensor, @@ -167,53 +148,51 @@ def rotary_embedding_fwdbwd( ) -> torch.Tensor: """ Args: - x: (batch, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim). - cos: (seqlen_ro, rotary_dim / 2) - sin: (seqlen_ro, rotary_dim / 2) - seqlen_offsets: integer or integer tensor of size (batch,) - cu_seqlens: (batch + 1,) or None + x: (N, T, H, D). + cos: (TR, DR / 2) + sin: (TR, DR / 2) + seqlen_offsets: integer or integer tensor of size (N,) + cu_seqlens: (N + 1,) or None max_seqlen: int Returns: - y: (batch, seqlen, nheads, headdim) + y: (N, T, H, D) """ is_varlen = cu_seqlens is not None + + B, T, H, D = x.shape if not is_varlen: - batch, seqlen, nheads, headdim = x.shape + N = B else: assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" - *_, nheads, headdim = x.shape - batch_p_1 = cu_seqlens.shape[0] - batch = batch_p_1 - 1 - seqlen = max_seqlen - seqlen_ro, rotary_dim = cos.shape + N, T = cu_seqlens.shape[0] - 1, max_seqlen + TR, DR = cos.shape assert sin.shape == cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim, "rotary_dim must be <= headdim" - assert headdim <= 256, "Only support headdim <= 256" - assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + DR *= 2 + + assert D <= 256, "Only support D <= 256" + assert TR >= T, "TR must be >= T" + assert DR <= D, "DR must be <= D" assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" - cos, sin = cos.contiguous(), sin.contiguous() if isinstance(seqlen_offsets, torch.Tensor): - assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.shape == (N,) assert seqlen_offsets.dtype in [torch.int32, torch.int64] else: - assert seqlen_offsets + seqlen <= seqlen_ro + assert seqlen_offsets + T <= TR output = torch.empty_like(x) if not inplace else x - if rotary_dim < headdim and not inplace: - output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + if DR < D and not inplace: + output[..., DR:].copy_(x[..., DR:]) - BLOCK_K = ( + BK = ( 32 - if rotary_dim <= 32 - else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + if DR <= 32 + else (64 if DR <= 64 else (128 if DR <= 128 else 256)) ) - BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) - def grid(META): return (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa + def grid(META): return (triton.cdiv(T, META['BM']), N, H) # noqa # Need this, otherwise Triton tries to launch from cuda:0 and we get # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) with torch.cuda.device(x.device.index): @@ -224,25 +203,16 @@ def grid(META): return (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # output, cu_seqlens, seqlen_offsets, - seqlen, # shapes - rotary_dim, - seqlen_ro, - # batch_strides if not varlen else 0 - output.stride(0) if not is_varlen else 0, - output.stride(-3), # seqlen_stride or total_seqlen_stride - output.stride(-2), # nheads_stride - output.stride(-1), # headdim_stride - # batch_strides if not varlen else 0 - x.stride(0) if not is_varlen else 0, - x.stride(-3), # seqlen stride or total_seqlen_stride - x.stride(-2), # nheads stride - x.stride(-1), # headdim stride - BLOCK_K, - BLOCK_M, - isinstance(seqlen_offsets, torch.Tensor), - is_varlen, - interleaved, - conjugate + B, + T, + H, + DR, + TR, + BK, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate ) return output @@ -322,22 +292,20 @@ def rotary_embedding( ): """ Args: - x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - cos, sin: (seqlen_rotary, rotary_dim / 2) + x: (N, T, H, D) + cos, sin: (TR, DR / 2) interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). inplace: if True, apply rotary embedding in-place. - seqlen_offsets: (batch_size,) or int. + seqlen_offsets: (N,) or int. Each sequence in x is shifted by this amount. Most commonly used in inference when we have KV cache. - cu_seqlens: (batch + 1,) or None + cu_seqlens: (N + 1,) or None max_seqlen: int Return: - out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. + out: (N, T, H, D) + DR must be <= D + Apply rotary embedding to the first DR of x. """ return RotaryEmbeddingFunction.apply( x, @@ -495,14 +463,14 @@ def forward( max_seqlen: Optional[int] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ - q: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim) + q: (N, T, H, D) + k: (N, T, H, D) seqlen_offset: - (batch_size,) or int. Each sequence in x is shifted by this amount. + (N,) or int. Each sequence in x is shifted by this amount. Most commonly used in inference when we have KV cache. - If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + If it's a tensor of shape (N,), then to update the cos / sin cache, one should pass in max_seqlen, which will update the cos / sin cache up to that length. - cu_seqlens: (batch + 1,) or None + cu_seqlens: (N + 1,) or None max_seqlen: int """ if max_seqlen is not None: