Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check torch.is_grad_enabled before calling customer flash atten ops #1397

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,9 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
training_mode=True,
):
is_grad = torch.is_grad_enabled() and qkv.requires_grad
is_grad = training_mode and qkv.requires_grad
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
Expand Down Expand Up @@ -516,7 +517,7 @@ def backward(ctx, dout, *args):
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None, None, None
return dqkv, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
Expand All @@ -534,8 +535,9 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
training_mode=True,
):
is_grad = torch.is_grad_enabled() and qkv.requires_grad
is_grad = training_mode and qkv.requires_grad
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
Expand Down Expand Up @@ -609,7 +611,7 @@ def backward(ctx, dout, *args):
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None, None, None, None, None
return dqkv, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnKVPackedFunc(torch.autograd.Function):
Expand All @@ -626,8 +628,9 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
training_mode=True,
):
is_grad = torch.is_grad_enabled() and any(
is_grad = training_mode and any(
x.requires_grad for x in [q, kv]
)
if softmax_scale is None:
Expand Down Expand Up @@ -695,7 +698,7 @@ def backward(ctx, dout, *args):
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None, None
return dq, dkv, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
Expand All @@ -716,8 +719,9 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
training_mode=True,
):
is_grad = torch.is_grad_enabled() and any(
is_grad = training_mode and any(
x.requires_grad for x in [q, kv]
)
if softmax_scale is None:
Expand Down Expand Up @@ -798,7 +802,7 @@ def backward(ctx, dout, *args):
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
Expand All @@ -816,8 +820,9 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
training_mode=True,
):
is_grad = torch.is_grad_enabled() and any(
is_grad = training_mode and any(
x.requires_grad for x in [q, k, v]
)
if softmax_scale is None:
Expand Down Expand Up @@ -883,7 +888,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
Expand All @@ -906,8 +911,9 @@ def forward(
deterministic,
return_softmax,
block_table,
training_mode=True,
):
is_grad = torch.is_grad_enabled() and any(
is_grad = training_mode and any(
x.requires_grad for x in [q, k, v]
)
if softmax_scale is None:
Expand Down Expand Up @@ -987,7 +993,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
Expand Down Expand Up @@ -1045,6 +1051,7 @@ def flash_attn_qkvpacked_func(
alibi_slopes,
deterministic,
return_attn_probs,
torch.is_grad_enabled(),
)


Expand Down Expand Up @@ -1122,6 +1129,7 @@ def flash_attn_kvpacked_func(
alibi_slopes,
deterministic,
return_attn_probs,
torch.is_grad_enabled(),
)


Expand Down Expand Up @@ -1198,6 +1206,7 @@ def flash_attn_func(
alibi_slopes,
deterministic,
return_attn_probs,
torch.is_grad_enabled(),
)


Expand Down Expand Up @@ -1263,6 +1272,7 @@ def flash_attn_varlen_qkvpacked_func(
alibi_slopes,
deterministic,
return_attn_probs,
torch.is_grad_enabled(),
)


Expand Down Expand Up @@ -1354,6 +1364,7 @@ def flash_attn_varlen_kvpacked_func(
alibi_slopes,
deterministic,
return_attn_probs,
torch.is_grad_enabled(),
)


Expand Down Expand Up @@ -1447,6 +1458,7 @@ def flash_attn_varlen_func(
deterministic,
return_attn_probs,
block_table,
torch.is_grad_enabled(),
)


Expand Down