Add support for left padding and masking in forward() and generate() #70
I tried validating the masking implementation with lm-eval-harness. Edit: I found an error in my lm-eval-harness evaluation with left padding. After fixing it, I get 59.1% with left padding + masking, the same as with right padding. But left padding without masking also gives 59.1%, since lm-eval-harness collates prompts by length, which minimizes the number of padding tokens. Likewise, I found no difference in pythia-1.4b performance under left padding with or without masking (52.1%, as reported with right padding). Switching to a fixed, random collate function exposes a difference on pythia-1.4b: 43.4% without masking and 52.1% with masking. But mamba-1.4b is virtually unchanged (59.0%). Maybe it's just more robust to long runs of unmasked padding tokens? TL;DR: I think my proposed padding + masking works, though it's not clear mamba even really needs the masking. |
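To illustrate why collating by length can hide the effect of masking, here is a minimal sketch that counts pad tokens under length-sorted versus shuffled batching. The prompt lengths, the batch size, and the helper `count_pad_tokens` are made up for illustration; this is not lm-eval-harness code.

```python
import random

def count_pad_tokens(lengths, batch_size):
    """Total pad tokens when every batch is padded to its longest sequence."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += sum(max(batch) - n for n in batch)
    return total

rng = random.Random(0)
lengths = [rng.randint(5, 200) for _ in range(64)]  # hypothetical prompt lengths

shuffled = lengths[:]
rng.shuffle(shuffled)

pads_random = count_pad_tokens(shuffled, batch_size=8)         # random collation
pads_sorted = count_pad_tokens(sorted(lengths), batch_size=8)  # collation by length
print(pads_random, pads_sorted)
```

With sorted lengths, each batch holds sequences of similar size, so few positions are padding and an unmasked model sees little of it.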
I just want to express my interest in left padding and masking. Thanks for the effort. |
Curious how to mask to train on outputs only if you think masking isn't needed. |
My understanding is that the output of the model (token logits) is causal by default, so there is no masking when the model is being trained autoregressively. For an idea on how to train the model look here: https://github.com/havenhq/mamba-chat/blob/main/train_mamba.py |
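For the question about training on outputs only, one common approach is to mask the loss rather than the model input. The sketch below is an assumption about how that could look, not necessarily what the linked script does: prompt positions get the label -100, which is the default `ignore_index` of `torch.nn.CrossEntropyLoss`. The helper `build_labels` and the token ids are hypothetical.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def build_labels(prompt_ids, response_ids):
    """Return (input_ids, labels) where only response tokens count toward the loss."""
    input_ids = list(prompt_ids) + list(response_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels

input_ids, labels = build_labels([11, 12, 13], [21, 22])
print(input_ids)  # [11, 12, 13, 21, 22]
print(labels)     # [-100, -100, -100, 21, 22]
```

The usual one-position shift between logits and labels for next-token prediction is left to the training loop.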
I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16) |
I think if you want pad+mask to be effective, you need to do pre-training without using a full sentence in chunk |
Yeah, I got the same error. Does anyone know how to solve it? |
I think this error happens because PyTorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue. On the other hand, to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?). Can someone verify if my thought process is accurate? |
Could you provide your change in |
config = MambaConfig(**config_data, **kwargs)
model = cls(config, device=device, dtype=dtype)
Why did you make this change with the kwargs?
@normster @tridao @albertfgu I believe this feature would be very nice to have in a stable release. Can we work towards merging this into main and have it in the next stable release? I am happy to help in any way. |
@@ -301,11 +305,11 @@ def mamba_inner_fn(
     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
     out_proj_weight, out_proj_bias,
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-    C_proj_bias=None, delta_softplus=True
+    C_proj_bias=None, mask=None, delta_softplus=True
 ):
     return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
I have some confusion about this line of code. Since MambaInnerFn doesn't provide a parameter for mask, it seems that the mask has no effect on the fwd and bwd passes.
Hence, how can the mask be applied to mark the sequence boundaries?
I might be wrong, but the PR modified Line 162 in the very same file, which is the definition of the MambaInnerFn.forward method. Also, the multiplications on Lines 181 and 220 use mask.
Thank you! You are right. I must have missed that line of code on L162.
But I see that in this PR attention_mask is only used in the forward pass and does not seem to be used in the backward pass. So when I tried to feed batch data with left padding and masking (batch_size, seq_len, hidden_dim) into the mamba block, it reported an error. Has anyone encountered a similar error?
Err Msg:
File "/blahblah/miniconda3/envs/dev/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
torch.autograd.backward(
File "/blahblah/miniconda3/envs/dev/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: function MambaInnerFnBackward returned an incorrect number of gradients (expected 16, got 15)
@xtwigs pointed out that one can add None in the returned gradients to fix the issue. (For fellows above who wonder where to put None, I did it at the end of the tuple.) I agree with @xtwigs, as we don't calculate a gradient on the mask tensor.
I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)
Yeah, I got the same error. Does anyone know how to solve it?
I think this error happens because PyTorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.
On the other hand, to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).
Can someone verify if my thought process is accurate?
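The rule behind the "expected 16, got 15" error can be reproduced on a toy op: backward() of a torch.autograd.Function must return exactly one value per forward() input, with None for inputs that are not differentiated. The sketch below is a standalone illustration; `MaskedScale` is a hypothetical stand-in and not the MambaInnerFn code from this PR.

```python
import torch

class MaskedScale(torch.autograd.Function):
    """Toy op: out = x * weight * mask. Stand-in for a fused op that takes a mask."""

    @staticmethod
    def forward(ctx, x, weight, mask):
        ctx.save_for_backward(x, weight, mask)
        return x * weight * mask

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, mask = ctx.saved_tensors
        grad_x = grad_out * weight * mask
        grad_weight = (grad_out * x * mask).sum()
        # One return value per forward() input (ctx excluded).
        # The mask is not differentiated, so its slot is None.
        return grad_x, grad_weight, None

x = torch.randn(2, 4, requires_grad=True)
weight = torch.tensor(1.5, requires_grad=True)
mask = torch.tensor([[0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]])  # row 0 is left-padded

MaskedScale.apply(x, weight, mask).sum().backward()
print(x.grad)  # zeros at the two masked positions, 1.5 elsewhere
```

Dropping the trailing None from the return tuple makes autograd raise the same kind of "incorrect number of gradients" RuntimeError shown in the traceback above.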
I have tried this approach but ran into CUDA OOM, even with many more GPUs and a much smaller seq_len (8 nodes, 64 A100 GPUs, and seq_len=512 for the 1.4B mamba model).
Hi, I also hit this error. Could you please provide some insight into how to fix this issue in the backward pass? |
I have tested your idea; it does not seem to work and causes CUDA OOM.
My code changes:
|
Is it using the default configuration defined in |
No. I am evaluating mamba model with 1.4B parameter size, where Here is my experiment details:
|
We are trying this PR because we want mamba to process packed sequences, as is done in transformer-based models. |
Hi! Sorry, but I haven't run into the OOM issue. I keep a fork of this code here in case it might help. (I also run this with checkpoint_lvl=0) |
Hello, thanks for the sharing! :D I might be wrong. But when I tried this branch and found that OOM still appeared in the backward pass, meanwhile Is there any reproducible test code snippet indicating left padding+ masking works in both fwd and bwd pass? |
Can you try this while disabling the dropout module added in the mamba simple code? (default was set to 0.1) |
Hi xtwigs, many thanks for your reply! |
Hi zigzagcai, have you solved the problem of masking in the mamba block? |
6d45666
to
41d30ce
Compare
This PR implements masking for left contiguous pad tokens by zeroing out intermediate state values, per the discussion at #66, for all three code paths: non-fused, fused without CUDA graph, and fused with CUDA graph. I'm not sure if this implementation is the best approach, so let me know if there's a better way to do things.
I've included a simple testing script at tests/test_padding.py which can be run with python tests/test_padding.py to compare prefill logits + generation outputs with and without left padding.
I also evaluated the models with/without batching + left-padding + masking on a question answering dataset and found nearly identical accuracies. Batching + left-padding + no masking hurts accuracy by a couple percentage points.
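The idea of zeroing intermediate state at left pad positions can be shown on a toy scalar recurrence. This is only a sketch under simplifying assumptions: `scan`, its coefficients `a` and `b`, and the `PAD` value are hypothetical, and the real model also has a conv1d and fused CUDA kernels that this does not model.

```python
def scan(xs, a=0.9, b=0.5, mask=None):
    """Toy recurrence h_t = a * h_{t-1} + b * x_t; mask[t] == 0 marks a pad position."""
    h, out = 0.0, []
    for t, x in enumerate(xs):
        h = a * h + b * x
        if mask is not None:
            h *= mask[t]  # zero the state at pad positions
        out.append(h)
    return out

PAD = 7.0  # hypothetical pad-token value, non-zero on purpose
real = [1.0, -2.0, 3.0]
padded = [PAD, PAD] + real  # left padding
mask = [0, 0, 1, 1, 1]

reference = scan(real)
with_mask = scan(padded, mask=mask)[2:]
without_mask = scan(padded)[2:]

print(with_mask == reference)     # True
print(without_mask == reference)  # False
```

Because the state is still zero when the first real token arrives, the masked run reproduces the unpadded outputs exactly, while the unmasked run carries pad contributions forward.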