Skip to content

Commit 7e0a972

Browse files
committed
[Misc.] Fix ptr overflows (#138)
1 parent b9a5cfb commit 7e0a972

File tree

4 files changed

+71
-39
lines changed

4 files changed

+71
-39
lines changed

fla/ops/common/chunk_h.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,15 @@ def chunk_fwd_kernel_h(
7474
if HEAD_FIRST:
7575
p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
7676
p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
77-
p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
77+
78+
o_h = (i_nh * NT + i_t).to(tl.int64) * K*V
79+
p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
7880
else:
7981
p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
8082
p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
81-
p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
83+
84+
o_h = ((boh + i_t) * H + i_h).to(tl.int64) * K*V
85+
p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
8286

8387
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
8488
# [BK, BT]
@@ -204,9 +208,11 @@ def chunk_bwd_kernel_dh(
204208

205209
for i_t in range(NT - 1, -1, -1):
206210
if HEAD_FIRST:
207-
p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
211+
o_dh = (i_nh * NT + i_t).to(tl.int64) * K*V
212+
p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
208213
else:
209-
p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
214+
o_dh = ((boh + i_t) * H + i_h).to(tl.int64) * K*V
215+
p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
210216
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
211217
last_idx = min(i_t * BT + BT, T) - 1
212218
# [BK, BT]

fla/ops/common/chunk_o.py

+22-22
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
# Copyright (c) 2024, Songlin Yang, Yu Zhang
2+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

44
from typing import Optional, Tuple
55

@@ -68,7 +68,7 @@ def chunk_fwd_kernel_o(
6868
k += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
6969
v += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
7070
o += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
71-
h += ((i_bh * NT + i_t) * K*V) if HEAD_FIRST else ((i_tg * H + i_h) * K*V)
71+
h += ((i_bh * NT + i_t).to(tl.int64) * K*V) if HEAD_FIRST else ((i_tg * H + i_h).to(tl.int64) * K*V)
7272

7373
b_o = tl.zeros([BT, BV], dtype=tl.float32)
7474
b_A = tl.zeros([BT, BT], dtype=tl.float32)
@@ -170,23 +170,23 @@ def chunk_bwd_kernel_dqkwg(
170170
bos, eos = i_b * T, i_b * T + T
171171

172172
# offset calculation
173-
v += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
174-
do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
175-
h += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K*V
176-
dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K*V
177-
q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
178-
k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
179-
dq += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
180-
dk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
173+
v += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
174+
do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
175+
h += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
176+
dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
177+
q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
178+
k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
179+
dq += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
180+
dk += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
181181
s_qk = K if HEAD_FIRST else H*K
182182
s_vo = V if HEAD_FIRST else H*V
183183
s_g = 1 if HEAD_FIRST else H
184184

185185
# for delta rule only
186186
if USE_DW:
187-
dw += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
188-
dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
189-
w += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
187+
dw += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
188+
dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
189+
w += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
190190

191191
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
192192
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
@@ -331,14 +331,14 @@ def chunk_bwd_kernel_dv(
331331
b_dv = tl.zeros([BT, BV], dtype=tl.float32)
332332

333333
# offset calculation
334-
q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
335-
k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
336-
do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
337-
dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
334+
q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
335+
k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
336+
do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
337+
dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
338338
s_qk = K if HEAD_FIRST else H*K
339339
s_vo = V if HEAD_FIRST else H*V
340340
s_g = 1 if HEAD_FIRST else H
341-
dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K*V
341+
dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
342342

343343
b_A = tl.zeros([BT, BT], dtype=tl.float32)
344344
for i_k in range(tl.cdiv(K, BK)):
@@ -414,10 +414,10 @@ def chunk_bwd_kernel_dv_local(
414414
bos, eos = i_b * T, i_b * T + T
415415

416416
# offset calculation
417-
q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
418-
k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
419-
do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
420-
dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
417+
q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
418+
k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
419+
do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
420+
dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
421421
s_qk = K if HEAD_FIRST else H*K
422422
s_vo = V if HEAD_FIRST else H*V
423423
s_g = 1 if HEAD_FIRST else H

fla/ops/retention/fused_chunk.py

+38-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
# Copyright (c) 2024, Songlin Yang, Yu Zhang
2+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

44
from typing import Optional, Tuple
55

@@ -55,21 +55,21 @@ def fused_chunk_retention_fwd_kernel(
5555
p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
5656
p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
5757
p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
58-
p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
58+
p_o = tl.make_block_ptr(o + (i_k*B*H+i_bh).to(tl.int64) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
5959

6060
if USE_INITIAL_STATE:
6161
p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
6262
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
6363

6464
NT = tl.cdiv(T, BT)
6565
for i in range(0, NT):
66+
# [BT, BK]
67+
b_q = tl.load(p_q, boundary_check=(0, 1))
68+
b_q = (b_q * scale).to(b_q.dtype)
6669
# [BK, BT]
6770
b_k = tl.load(p_k, boundary_check=(0, 1))
6871
# [BT, BV]
6972
b_v = tl.load(p_v, boundary_check=(0, 1))
70-
# [BT, BK]
71-
b_q = tl.load(p_q, boundary_check=(0, 1))
72-
b_q = (b_q * scale).to(b_k.dtype)
7373

7474
# [BT, BT]
7575
b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s
@@ -138,7 +138,7 @@ def fused_chunk_retention_bwd_kernel(
138138
p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
139139
p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
140140
p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
141-
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))
141+
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))
142142

143143
# [BT, K]
144144
b_k = tl.load(p_k, boundary_check=(0, 1))
@@ -174,8 +174,8 @@ def fused_chunk_retention_bwd_kernel(
174174
p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
175175
p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
176176
p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
177-
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
178-
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
177+
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
178+
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H).to(tl.int64) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
179179
# [K, BT]
180180
b_q = tl.load(p_q, boundary_check=(0, 1))
181181
# [BT, BK]
@@ -244,9 +244,21 @@ def forward(ctx, q, k, v, scale, initial_state, output_final_state):
244244

245245
grid = (NV, NK, B * H)
246246
fused_chunk_retention_fwd_kernel[grid](
247-
q, k, v, o, initial_state, final_state,
247+
q,
248+
k,
249+
v,
250+
o,
251+
initial_state,
252+
final_state,
248253
scale,
249-
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
254+
B=B,
255+
H=H,
256+
T=T,
257+
K=K,
258+
V=V,
259+
BT=BT,
260+
BK=BK,
261+
BV=BV,
250262
USE_INITIAL_STATE=initial_state is not None,
251263
STORE_FINAL_STATE=output_final_state,
252264
CHECK=CHECK,
@@ -279,9 +291,23 @@ def backward(ctx, do, dht=None):
279291
grid = (NV, NK, B * H)
280292

281293
fused_chunk_retention_bwd_kernel[grid](
282-
q, k, v, do, dq, dk, dv, initial_state,
294+
q,
295+
k,
296+
v,
297+
do,
298+
dq,
299+
dk,
300+
dv,
301+
initial_state,
283302
scale,
284-
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
303+
B=B,
304+
T=T,
305+
H=H,
306+
K=K,
307+
V=V,
308+
BT=BT,
309+
BK=BK,
310+
BV=BV,
285311
USE_INITIAL_STATE=initial_state is not None,
286312
CHECK=ctx.CHECK,
287313
num_warps=num_warps,

fla/ops/simple_gla/chunk.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
# Copyright (c) 2024, Songlin Yang, Yu Zhang
2+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

44
from typing import Optional, Tuple
55

0 commit comments

Comments
 (0)