From d2bfdcde0af71e759b832757844254c275c06a53 Mon Sep 17 00:00:00 2001
From: XiaobingSuper
Date: Thu, 19 Dec 2024 10:06:50 +0800
Subject: [PATCH] check torch.is_grad_enabled before calling custom flash attention ops

---
 flash_attn/flash_attn_interface.py | 36 ++++++++++++++++++++----------
 1 file changed, 24 insertions(+), 12 deletions(-)

diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index f83a7728a..bb8cce61d 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -451,8 +451,9 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        training_mode=True,
     ):
-        is_grad = torch.is_grad_enabled() and qkv.requires_grad
+        is_grad = training_mode and qkv.requires_grad
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
@@ -516,7 +517,7 @@ def backward(ctx, dout, *args):
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None


 class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@@ -534,8 +535,9 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        training_mode=True,
     ):
-        is_grad = torch.is_grad_enabled() and qkv.requires_grad
+        is_grad = training_mode and qkv.requires_grad
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
@@ -609,7 +611,7 @@ def backward(ctx, dout, *args):
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None, None, None


 class FlashAttnKVPackedFunc(torch.autograd.Function):
@@ -626,8 +628,9 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        training_mode=True,
     ):
-        is_grad = torch.is_grad_enabled() and any(
+        is_grad = training_mode and any(
             x.requires_grad for x in [q, kv]
         )
         if softmax_scale is None:
@@ -695,7 +698,7 @@ def backward(ctx, dout, *args):
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None


 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
@@ -716,8 +719,9 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        training_mode=True,
     ):
-        is_grad = torch.is_grad_enabled() and any(
+        is_grad = training_mode and any(
             x.requires_grad for x in [q, kv]
         )
         if softmax_scale is None:
@@ -798,7 +802,7 @@ def backward(ctx, dout, *args):
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None


 class FlashAttnFunc(torch.autograd.Function):
@@ -816,8 +820,9 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        training_mode=True,
     ):
-        is_grad = torch.is_grad_enabled() and any(
+        is_grad = training_mode and any(
             x.requires_grad for x in [q, k, v]
         )
         if softmax_scale is None:
@@ -883,7 +888,7 @@ def backward(ctx, dout, *args):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None


 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -906,8 +911,9 @@ def forward(
         deterministic,
         return_softmax,
         block_table,
+        training_mode=True,
     ):
-        is_grad = torch.is_grad_enabled() and any(
+        is_grad = training_mode and any(
             x.requires_grad for x in [q, k, v]
         )
         if softmax_scale is None:
@@ -987,7 +993,7 @@ def backward(ctx, dout, *args):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None


 def flash_attn_qkvpacked_func(
@@ -1045,6 +1051,7 @@ def flash_attn_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        torch.is_grad_enabled(),
     )


@@ -1122,6 +1129,7 @@ def flash_attn_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        torch.is_grad_enabled(),
     )


@@ -1198,6 +1206,7 @@ def flash_attn_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        torch.is_grad_enabled(),
     )


@@ -1263,6 +1272,7 @@ def flash_attn_varlen_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        torch.is_grad_enabled(),
     )


@@ -1354,6 +1364,7 @@ def flash_attn_varlen_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        torch.is_grad_enabled(),
     )


@@ -1447,6 +1458,7 @@ def flash_attn_varlen_func(
         deterministic,
         return_attn_probs,
         block_table,
+        torch.is_grad_enabled(),
     )

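
Note (not part of the patch): the change moves the gradient-mode check out of the
autograd Function's forward and into the public wrapper functions, which capture
torch.is_grad_enabled() at the call site and pass it down as training_mode, so the
Function decides from the caller's grad mode whether to save tensors for backward.
Below is a minimal sketch of that calling convention using a hypothetical
_ToyFunc/toy_func pair; it is an illustration of the pattern, not flash-attn code.

import torch

class _ToyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, training_mode=True):
        # Decide whether to stash tensors for backward from the flag computed
        # by the caller rather than from the ambient grad mode inside forward.
        if training_mode and x.requires_grad:
            ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        # One gradient per forward input; None for the non-tensor flag.
        return grad_out * 2, None

def toy_func(x):
    # Mirror the patched wrappers: capture the caller's grad mode here.
    return _ToyFunc.apply(x, torch.is_grad_enabled())

x = torch.randn(4, requires_grad=True)
with torch.no_grad():
    _ = toy_func(x)   # training_mode=False -> nothing saved for backward
y = toy_func(x)       # training_mode=True when grad is enabled
y.sum().backward()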