Skip to content

Commit

Permalink
Small fixes in fa tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
Ognjen committed Feb 22, 2024
1 parent 9c16a78 commit d0d5e0d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 6 additions & 0 deletions python/perf-kernels/06-fused-attention-fwd-transV.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def forward(ctx, q, k, v, sm_scale):
num_stages = 1
## causal=False likes to pre load v but causal=True does not
pre_load_v = False if causal else True
slice_k_tile = 32
kpack = 1
else:
## D_HEAD = 128
## For fp16, pick BLOCK_M=256, num_warps=8
Expand All @@ -170,6 +172,8 @@ def forward(ctx, q, k, v, sm_scale):
num_warps = BLOCK_M // 32
num_stages = 1
pre_load_v = False
slice_k_tile = 32
kpack = 1

grid = ( triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
Expand All @@ -189,6 +193,8 @@ def forward(ctx, q, k, v, sm_scale):
num_warps = num_warps,
num_stages = num_stages,
pre_load_v = pre_load_v,
slice_k_tile = slice_k_tile,
kpack = kpack,
)

return o
Expand Down
2 changes: 1 addition & 1 deletion python/tutorials/06-fused-attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q,
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
],
Expand Down

0 comments on commit d0d5e0d

Please sign in to comment.