Support bigger head dims #3

Is it possible to support this version of flash attention? I find that by just editing your code and replacing the 128 with 512, I get an error with Stable Diffusion 1.5. Is it possible to write the code to dynamically switch the head dim based on the model that's running? Furthermore, I don't have enough knowledge to know whether just replacing the 128 would make it support the 512 version. If you could look into that, it would be great.

Comments
Yeah, in my quickdif app I have a setup that uses try/catch instead of a dim check to support any flash attention library:

```python
import torch
from torch import Tensor
from typing import Callable


def _patch_sdpa(
    patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None], Tensor],
):
    """(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)"""
    torch_sdpa = torch.nn.functional.scaled_dot_product_attention

    def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        # Try the patched attention first; on any failure (unsupported head
        # dim, mask, dtype, ...) fall back to the original torch SDPA.
        try:
            return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale)
        except Exception:
            hidden_states = torch_sdpa(
                query=query,
                key=key,
                value=value,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
                scale=scale,
            )
            return hidden_states

    torch.nn.functional.scaled_dot_product_attention = sdpa_hijack_flash


try:
    from flash_attn import flash_attn_func

    def sdpa_hijack_flash(q, k, v, m, p, c, s):
        # flash_attn_func has no attn_mask argument, so bail out to the
        # fallback path whenever a mask is supplied.
        assert m is None
        # flash_attn expects (batch, seq_len, heads, head_dim) while SDPA
        # uses (batch, heads, seq_len, head_dim), hence the transposes.
        result = flash_attn_func(
            q=q.transpose(1, 2),
            k=k.transpose(1, 2),
            v=v.transpose(1, 2),
            dropout_p=p,
            softmax_scale=s if s else q.shape[-1] ** (-0.5),
            causal=c,
        )
        assert isinstance(result, Tensor)
        return result.transpose(1, 2)

    _patch_sdpa(sdpa_hijack_flash)
    print("# # #\nPatched SDPA with Flash Attention\n# # #")
except ImportError as e:
    print(f"# # #\nCould not load Flash Attention for hijack:\n{e}\n# # #")
```

I don't see why it wouldn't also work for ComfyUI, but I haven't used Comfy in a while.
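For reference, a quick sanity check of the fallback path (a minimal sketch, assuming flash-attn and a CUDA device are available; the shapes here are arbitrary, and head dim 512 is past what stock flash-attn kernels support, so the flash path raises and the hijack falls back):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- the layout torch SDPA expects.
q, k, v = (torch.randn(1, 8, 64, 512, device="cuda", dtype=torch.float16) for _ in range(3))

# flash_attn_func raises on the oversized head dim, the hijack catches it,
# and the call transparently falls back to the original torch SDPA.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 64, 512])
```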
So if I just slapped that code into your node, it could work?
Probably. Replace the existing block, but leave the node mappings at the bottom; Comfy needs those.
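To illustrate (a rough sketch only; `FlashAttentionNode` and the mapping entries below are placeholders, not the repo's actual names), the file would end up shaped like:

```python
# Top of the file: the try/except flash_attn patch from above, which runs
# once at import time and replaces torch's scaled_dot_product_attention.

class FlashAttentionNode:  # placeholder for the repo's actual node class
    ...

# ComfyUI discovers custom nodes through these module-level dicts,
# which is why they have to stay at the bottom of the file.
NODE_CLASS_MAPPINGS = {"FlashAttentionNode": FlashAttentionNode}
NODE_DISPLAY_NAME_MAPPINGS = {"FlashAttentionNode": "Flash Attention Node"}
```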
Yep, copy/paste works.
I guess this can be closed then |