Support bigger head dims #3

Closed · Kademo15 opened this issue Sep 21, 2024 · 5 comments

Kademo15 commented Sep 21, 2024

Is it possible to support this version of flash attention? I find that just editing your code and replacing the 128 with 512 gives me an error with Stable Diffusion 1.5. Would it be possible to have the code dynamically switch the head dim based on the model that's running? I also don't have enough knowledge to know whether just replacing 128 would be enough to support the 512 version. If you could look into that, it would be great.
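(For illustration, a hypothetical sketch of the kind of hard-coded head-dim gate being described; the node's actual code may differ, and the names MAX_FLASH_HEAD_DIM and flash_dim_supported are placeholders, not from the repo.)

import torch
from torch import Tensor

# Hypothetical constant corresponding to the 128 the question proposes raising
# to 512. Flash attention kernels only cover head dims up to a build-dependent
# limit, so swapping the constant only helps if the installed kernel actually
# supports the larger size.
MAX_FLASH_HEAD_DIM = 128


def flash_dim_supported(query: Tensor, max_dim: int = MAX_FLASH_HEAD_DIM) -> bool:
    # torch's SDPA convention is (batch, heads, seq_len, head_dim),
    # so the head dim is the last axis.
    return query.shape[-1] <= max_dim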

Beinsezii (Owner) commented

Yeah, in my quickdif app I have a setup that uses try/except instead of a dim check, so it supports any flash attention library:

import torch
from torch import Tensor
from typing import Callable


def _patch_sdpa(
    patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None], Tensor],
):
    """(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)"""

    torch_sdpa = torch.nn.functional.scaled_dot_product_attention

    def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        try:
            # Try the flash attention path first; fall back to stock SDPA if the
            # kernel rejects the inputs (unsupported head dim, dtype, mask, etc.).
            return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale)
        except Exception:
            hidden_states = torch_sdpa(
                query=query,
                key=key,
                value=value,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
                scale=scale,
            )
        return hidden_states

    torch.nn.functional.scaled_dot_product_attention = sdpa_hijack_flash


try:
    from flash_attn import flash_attn_func

    def sdpa_hijack_flash(q, k, v, m, p, c, s):
        assert m is None  # flash_attn_func has no attention-mask argument
        # torch's SDPA layout is (batch, heads, seq, head_dim); flash_attn_func
        # expects (batch, seq, heads, head_dim), hence the transposes in and out.
        result = flash_attn_func(
            q=q.transpose(1, 2),
            k=k.transpose(1, 2),
            v=v.transpose(1, 2),
            dropout_p=p,
            softmax_scale=s if s is not None else q.shape[-1] ** (-0.5),
            causal=c,
        )
        assert isinstance(result, Tensor)
        return result.transpose(1, 2)

    _patch_sdpa(sdpa_hijack_flash)
    print("# # #\nPatched SDPA with Flash Attention\n# # #")
except ImportError as e:
    print(f"# # #\nCould not load Flash Attention for hijack:\n{e}\n# # #")

I don't see why it wouldn't also work for ComfyUI, but I haven't used Comfy in a while.
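(A quick sanity check as an illustrative example, not from the thread: once the module above is imported, torch's SDPA is patched globally, and inputs the flash kernel rejects simply take the fallback path inside the try/except. Assumes a CUDA device and fp16 tensors, since flash attention requires both.)

import torch

# head_dim = 512 is larger than many flash attention builds support, so this
# call exercises the stock-SDPA fallback instead of raising an error.
q = torch.randn(1, 8, 64, 512, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 64, 512])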

Kademo15 (Author) commented

So if I just slapped that code into your node, it could work?

Beinsezii (Owner) commented Sep 22, 2024

Probably. Replace the existing block, but leave the node mappings at the bottom; Comfy needs those.
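(For context, a generic ComfyUI custom-node skeleton rather than this repo's actual classes; the node and class names below are placeholders. ComfyUI discovers custom nodes by importing the file and reading the module-level mapping dicts, which is why they need to stay at the bottom when the attention code above them is replaced.)

class FlashAttentionPatchNode:  # hypothetical node class, not the repo's
    CATEGORY = "model_patches"
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"model": ("MODEL",)}}

    def patch(self, model):
        # apply the SDPA hijack here, then hand the model back to the graph
        return (model,)


# ComfyUI reads these at import time to register the node and its display name.
NODE_CLASS_MAPPINGS = {"FlashAttentionPatch": FlashAttentionPatchNode}
NODE_DISPLAY_NAME_MAPPINGS = {"FlashAttentionPatch": "Flash Attention Patch"}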

sleppyrobot commented

Yep, copy/paste works.

Kademo15 (Author) commented Oct 5, 2024

I guess this can be closed then
