|
1 | 1 | # -*- coding: utf-8 -*-
|
2 |
| -# Copyright (c) 2024, Songlin Yang, Yu Zhang |
| 2 | +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang |
3 | 3 |
|
4 | 4 | from typing import Optional, Tuple
|
5 | 5 |
|
@@ -55,21 +55,21 @@ def fused_chunk_retention_fwd_kernel(
|
55 | 55 | p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
|
56 | 56 | p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
|
57 | 57 | p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
|
58 |
| - p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) |
| 58 | + p_o = tl.make_block_ptr(o + (i_k*B*H+i_bh).to(tl.int64) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) |
59 | 59 |
|
60 | 60 | if USE_INITIAL_STATE:
|
61 | 61 | p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
62 | 62 | b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
63 | 63 |
|
64 | 64 | NT = tl.cdiv(T, BT)
|
65 | 65 | for i in range(0, NT):
|
| 66 | + # [BT, BK] |
| 67 | + b_q = tl.load(p_q, boundary_check=(0, 1)) |
| 68 | + b_q = (b_q * scale).to(b_q.dtype) |
66 | 69 | # [BK, BT]
|
67 | 70 | b_k = tl.load(p_k, boundary_check=(0, 1))
|
68 | 71 | # [BT, BV]
|
69 | 72 | b_v = tl.load(p_v, boundary_check=(0, 1))
|
70 |
| - # [BT, BK] |
71 |
| - b_q = tl.load(p_q, boundary_check=(0, 1)) |
72 |
| - b_q = (b_q * scale).to(b_k.dtype) |
73 | 73 |
|
74 | 74 | # [BT, BT]
|
75 | 75 | b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s
|
@@ -138,7 +138,7 @@ def fused_chunk_retention_bwd_kernel(
|
138 | 138 | p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
|
139 | 139 | p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
|
140 | 140 | p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
|
141 |
| - p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0)) |
| 141 | + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0)) |
142 | 142 |
|
143 | 143 | # [BT, K]
|
144 | 144 | b_k = tl.load(p_k, boundary_check=(0, 1))
|
@@ -174,8 +174,8 @@ def fused_chunk_retention_bwd_kernel(
|
174 | 174 | p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
|
175 | 175 | p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
176 | 176 | p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
177 |
| - p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) |
178 |
| - p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) |
| 177 | + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) |
| 178 | + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H).to(tl.int64) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) |
179 | 179 | # [K, BT]
|
180 | 180 | b_q = tl.load(p_q, boundary_check=(0, 1))
|
181 | 181 | # [BT, BK]
|
@@ -244,9 +244,21 @@ def forward(ctx, q, k, v, scale, initial_state, output_final_state):
|
244 | 244 |
|
245 | 245 | grid = (NV, NK, B * H)
|
246 | 246 | fused_chunk_retention_fwd_kernel[grid](
|
247 |
| - q, k, v, o, initial_state, final_state, |
| 247 | + q, |
| 248 | + k, |
| 249 | + v, |
| 250 | + o, |
| 251 | + initial_state, |
| 252 | + final_state, |
248 | 253 | scale,
|
249 |
| - B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, |
| 254 | + B=B, |
| 255 | + H=H, |
| 256 | + T=T, |
| 257 | + K=K, |
| 258 | + V=V, |
| 259 | + BT=BT, |
| 260 | + BK=BK, |
| 261 | + BV=BV, |
250 | 262 | USE_INITIAL_STATE=initial_state is not None,
|
251 | 263 | STORE_FINAL_STATE=output_final_state,
|
252 | 264 | CHECK=CHECK,
|
@@ -279,9 +291,23 @@ def backward(ctx, do, dht=None):
|
279 | 291 | grid = (NV, NK, B * H)
|
280 | 292 |
|
281 | 293 | fused_chunk_retention_bwd_kernel[grid](
|
282 |
| - q, k, v, do, dq, dk, dv, initial_state, |
| 294 | + q, |
| 295 | + k, |
| 296 | + v, |
| 297 | + do, |
| 298 | + dq, |
| 299 | + dk, |
| 300 | + dv, |
| 301 | + initial_state, |
283 | 302 | scale,
|
284 |
| - B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, |
| 303 | + B=B, |
| 304 | + T=T, |
| 305 | + H=H, |
| 306 | + K=K, |
| 307 | + V=V, |
| 308 | + BT=BT, |
| 309 | + BK=BK, |
| 310 | + BV=BV, |
285 | 311 | USE_INITIAL_STATE=initial_state is not None,
|
286 | 312 | CHECK=ctx.CHECK,
|
287 | 313 | num_warps=num_warps,
|
|
0 commit comments