
Attention masking issue for batch submissions, Huggingface #16

Open · prwoolley opened this issue Dec 17, 2024 · 3 comments

@prwoolley

Thanks for the model! This is regarding the default behavior on Huggingface. When running a batched forward pass for inference, feeding the model the attention mask created by the tokenizer throws an error: the tokenizer produces the mask as integers, whereas a downstream step expects floats. This can be fixed by simply casting the mask to a float dtype before the forward pass, but that is an extra step the user has to figure out. Could this become a default tokenizer step?
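
For illustration, a minimal sketch of that cast (the hub ID matches the cached path further down this thread; the loading call and forward signature are assumptions based on the standard Hugging Face API, not a verified AMPLIFY recipe):

    import torch
    from transformers import AutoModel, AutoTokenizer

    name = "chandar-lab/AMPLIFY_350M"  # assumed hub ID
    tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    model = AutoModel.from_pretrained(name, trust_remote_code=True)

    # Batch of sequences of different lengths, padded by the tokenizer
    batch = tokenizer(["MSVKKPLV", "MKT"], padding=True, return_tensors="pt")

    # The tokenizer returns an integer mask; cast it to the model's float dtype
    # before the forward pass so the downstream attention step receives floats
    attention_mask = batch["attention_mask"].to(model.dtype)

    with torch.no_grad():
        out = model(batch["input_ids"], attention_mask)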

@prwoolley (Author)

There might be more issues than just this one. The cast worked for the two sequences I first tested, but I got an xformers error when I tried a larger batch with more diverse sequence lengths. I tried changing the dtype to float16 and padding the sequences to multiples of 8, but to no avail.

NotImplementedError: No operator found for memory_efficient_attention_forward with inputs:
query : shape=(512, 91, 10, 64) (torch.float32)
key : shape=(512, 91, 10, 64) (torch.float32)
value : shape=(512, 91, 10, 64) (torch.float32)
attn_bias : <class 'torch.Tensor'>
p : 0
[email protected] is not supported because:
dtype=torch.float32 (supported: {torch.bfloat16, torch.float16})
attn_bias type is <class 'torch.Tensor'>
cutlassF-pt is not supported because:
attn_bias.stride(-2) % 4 != 0 (attn_bias.stride() = (82810, 8281, 91, 1))
attn_bias.stride(-2) % 4 != 0 (attn_bias.stride() = (82810, 8281, 91, 1))
attn_bias.stride(-2) % 4 != 0 (attn_bias.stride() = (82810, 8281, 91, 1))
HINT: To use an attn_bias with a sequence length that is not a multiple of 8, you need to ensure memory is aligned by slicing a bigger tensor. Example: use attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5] instead of torch.zeros([1, 1, 5, 5])

I saw in a different thread that the CPU version uses default torch instead of xformers. Is there a way to add functionality so that the GPU version can also disable xformers?
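
For reference, a minimal sketch of the alignment trick the hint describes (allocate a bias whose last dimension is padded to a multiple of 8, then slice back to the true length so the underlying memory stays aligned; shapes taken from the error above):

    import torch

    B, H, M = 512, 10, 91          # batch size, heads, sequence length
    M_pad = (M + 7) // 8 * 8       # next multiple of 8, here 96

    # Slicing a bigger tensor keeps stride(-2) at 96, a multiple of 8,
    # which satisfies the xformers alignment check
    attn_bias = torch.zeros(B, H, M, M_pad, dtype=torch.float16)[:, :, :, :M]
    assert attn_bias.stride(-2) % 8 == 0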

@prwoolley (Author)

Placing my fix here in case people need it for batching tokens. I commented out a portion of the amplify.py script located here (on my machine):
Location: ~/.cache/huggingface/modules/transformers_modules/chandar-lab/AMPLIFY_350M/a735ae820dbf73d7c770623d70cee565038e0f68

Portion of the script changed (removing the xformers memory_efficient_attention implementation):

        # Compute the attention using xformers if the tensors are on GPU
        #if x.is_cuda:
        #    # Input and output are of dimension (B, M, H, K) where B is the batch size, M the sequence length,
        #    # H the number of heads, and K the embedding size per head
        #    attn = memory_efficient_attention(
        #        query=xq,
        #        key=xk,
        #        value=xv,
        #        attn_bias=attention_mask,
        #        p=self.config.dropout_prob if self.training else 0,
        #    )
        #else:
        #    # Input and output are of dimension (B, H, M, K)
        #    attn = scaled_dot_product_attention(
        #        query=xq.transpose(1, 2),
        #        key=xk.transpose(1, 2),
        #        value=xv.transpose(1, 2),
        #        attn_mask=attention_mask,
        #        dropout_p=self.config.dropout_prob if self.training else 0,
        #    ).transpose(1, 2)

        attn = scaled_dot_product_attention(
            query=xq.transpose(1, 2),
            key=xk.transpose(1, 2),
            value=xv.transpose(1, 2),
            attn_mask=attention_mask,
            dropout_p=self.config.dropout_prob if self.training else 0,
        ).transpose(1, 2)
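
Note that this patch takes the scaled_dot_product_attention branch unconditionally, on GPU as well as CPU, so the xformers operator selection is skipped entirely and PyTorch's own SDPA backend handles the attention instead.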

@qfournier (Collaborator)

Hi @prwoolley, thank you for your interest in AMPLIFY! Apologies for the delayed response.

You’re right, the pad_mask dtype should be handled automatically in the forward pass. It can be addressed by determining whether you are using torch.float32, torch.float16, or torch.bfloat16, and casting the pad_mask to the appropriate dtype.
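
For example, a minimal sketch of that cast inside the forward pass (assuming xq carries the compute dtype, as in the snippet below):

        # Cast the pad_mask to the activations' floating dtype before it is
        # used as an attention bias
        if pad_mask is not None and pad_mask.dtype != xq.dtype:
            pad_mask = pad_mask.to(xq.dtype)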

Alternatively, you could verify the pad_mask dtype before using the xformers implementation:

        # Compute the attention using xformers if the tensors are on GPU
        if x.is_cuda and pad_mask.dtype == x.dtype:
            # Input and output are of dimension (B, M, H, K) where B is the batch size, M the sequence length,
            # H the number of heads, and K the embedding size per head
            attn = memory_efficient_attention(
                query=xq,
                key=xk,
                value=xv,
                attn_bias=pad_mask,
                p=self.config.dropout_prob if self.training else 0,
            )
        else:
            # Input and output are of dimension (B, H, M, K)
            attn = scaled_dot_product_attention(
                query=xq.transpose(1, 2),
                key=xk.transpose(1, 2),
                value=xv.transpose(1, 2),
                attn_mask=pad_mask,
                dropout_p=self.config.dropout_prob if self.training else 0,
            ).transpose(1, 2)

The second issue you observed might be related to the pad_mask dtype.

Our group is working on an improved version of the codebase. We will fix this in the next release.
