diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
index d0ebfba54972..1e30ed29887b 100644
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -861,6 +861,14 @@ def forward(ctx, q, k, v, o, metadata):
             ENABLE_DROPOUT=metadata.dropout_p > 0.0,
             RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax
         )
+        ## restore the grid for bwd kernel
+        best_config = attn_fwd.get_best_config()
+        block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
+        grid = (
+            triton.cdiv(metadata.max_seqlens_q, block_m),
+            nheads_q,
+            batch
+        )
 
         ctx.save_for_backward(q, k, v, o, M)
         ctx.grid = grid
@@ -1134,6 +1142,12 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
     sm_scale = 0.5
+    input_metadata = MetaData(sm_scale=sm_scale)
+    input_metadata.max_seqlens_q = N_CTX
+    input_metadata.max_seqlens_k = N_CTX
+    if causal:
+        input_metadata.need_causal()
+
     split_kernel = True
     dout = torch.randn_like(q)
     # reference implementation
@@ -1148,7 +1162,9 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     ref_dk, k.grad = k.grad.clone(), None
     ref_dq, q.grad = q.grad.clone(), None
     # # triton implementation
-    tri_out, _ = attention(q, k, v, causal, None, sm_scale, 0, False, True)
+    o = torch.empty_like(q)
+    tri_out, _ = attention(q, k, v, o, input_metadata)
+
     tri_out.backward(dout)#dout)
     tri_dv, v.grad = v.grad.clone(), None
     tri_dk, k.grad = k.grad.clone(), None
@@ -1334,4 +1350,4 @@ def main():
         run_benchmark(custom_config)
 
 if __name__ == '__main__':
-    sys.exit(main())
\ No newline at end of file
+    sys.exit(main())
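
The new forward() code recovers the autotuned BLOCK_M by splitting the string form of attn_fwd.get_best_config() and then rebuilds the launch grid that the backward kernel expects. For context, below is a minimal sketch (not part of the patch) of the same grid computation that reads the meta-parameter directly; it assumes the Triton build in use exposes the selected meta-parameters on Config.kwargs, and the helper name restore_bwd_grid is hypothetical.

    import triton

    def restore_bwd_grid(metadata, nheads_q, batch):
        # attn_fwd is the module's autotuned @triton.autotune kernel (name taken from the patch).
        # Assumption: get_best_config() returns a triton.Config whose meta-parameters
        # (including BLOCK_M) are stored in the .kwargs dict, avoiding string parsing.
        best_config = attn_fwd.get_best_config()
        block_m = int(best_config.kwargs["BLOCK_M"])
        # Same (num_m_blocks, nheads_q, batch) layout as the grid built in forward().
        return (triton.cdiv(metadata.max_seqlens_q, block_m), nheads_q, batch)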