diff --git a/python/perf-kernels/06-fused-attention-fwd-transV.py b/python/perf-kernels/06-fused-attention-fwd-transV.py
index 955eb5106dbc..35a6da764746 100644
--- a/python/perf-kernels/06-fused-attention-fwd-transV.py
+++ b/python/perf-kernels/06-fused-attention-fwd-transV.py
@@ -159,6 +159,8 @@ def forward(ctx, q, k, v, sm_scale):
             num_stages = 1
             ## causal=False likes to pre load v but causal=True does not
             pre_load_v = False if causal else True
+            slice_k_tile = 32
+            kpack = 1
         else:
             ## D_HEAD = 128
             ## For fp16, pick BLOCK_M=256, num_warps=8
@@ -170,6 +172,8 @@ def forward(ctx, q, k, v, sm_scale):
             num_warps = BLOCK_M // 32
             num_stages = 1
             pre_load_v = False
+            slice_k_tile = 32
+            kpack = 1
 
         grid = ( triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
         M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
@@ -189,6 +193,8 @@ def forward(ctx, q, k, v, sm_scale):
             num_warps = num_warps,
             num_stages = num_stages,
             pre_load_v = pre_load_v,
+            slice_k_tile = slice_k_tile,
+            kpack = kpack,
         )
         return o
 
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index fe6a608f02e1..6ad01cbadc83 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -87,7 +87,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q,
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4),
-        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=4),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
     ],