Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for left padding and masking in forward() and generate() #70

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

normster
Copy link

This PR implements masking for left contiguous pad tokens by zeroing out intermediate state values, per the discussion at #66, for all three code paths: non-fused, fused without CUDA graph, and fused with CUDA graph. I'm not sure if this implementation is the best approach, so let me know if there's a better way to do things.

I've included a simple testing script at tests/test_padding.py which can be run with python tests/test_padding.py to compare prefill logits + generation outputs with and without left padding.

I also evaluated the models with/without batching + left-padding + masking on a question answering dataset and found nearly identical accuracies. Batching + left-padding + no masking hurts accuracy by a couple percentage points.

@normster
Copy link
Author

normster commented Dec 20, 2023

I tried validating the masking implementation with lm-eval-harness. On HellaSwag, mamba-1.4b with right padding still achieves the reported 59.1% accuracy. Switching to left padding drops this to 55.8% accuracy, and changing lm-eval-harness to 1) construct padding masks and 2) use padding masks does not recover any performance (still 55.8%). I might be using the padding masks incorrectly but it's pretty straightforward so I suspect the issue might lie in my mamba masking change in this PR.

Edit: I found an error in my evaluation of lm-eval-harness with left padding. After fixing, I get 59.1% with left padding + masking, the same as with right-padding. But left padding + no masking also gives 59.1% since lm-eval-harness collates prompts by length which minimizes the number of padding tokens. I found no difference in pythia-1.4b performance left padding with/without masking (52.1% as reported with right padding). Switching to a fixed, random collate function exposes a difference in performance on pythia-1.4b: 43.4% without masking and 52.1% with masking. But mamba-1.4b is virtually unchanged (59.0%). Maybe it's just more robust to long runs of unmasked padding tokens?

TL;DR: I think my proposed padding + masking works, though it's not clear mamba even really needs the masking.

@pjsample
Copy link

I just want to express my interest in left padding and masking. Thanks for the effort.

@thistleknot
Copy link

Curious how to mask to train on outputs only if you think masking isn't needed.

@pjsample
Copy link

Curious how to mask to train on outputs only if you think masking isn't needed.

My understanding is that the output of the model (token logits) is causal by default, so there is no masking when the model is being trained autoregressively. For an idea on how to train the model look here: https://github.com/havenhq/mamba-chat/blob/main/train_mamba.py

@sentialx
Copy link

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

@junphine
Copy link

junphine commented Jan 2, 2024

I think if you want pad+mask to be effective, you need to do pre-training without using a full sentence in chunk

@sunningmbzuai
Copy link

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

@xtwigs
Copy link

xtwigs commented Jan 11, 2024

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.

On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).

Can someone verify if my thought process is accurate?

@sunningmbzuai
Copy link

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.

On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).

Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Comment on lines +247 to +248
config = MambaConfig(**config_data, **kwargs)
model = cls(config, device=device, dtype=dtype)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you make this change with the kwargs?

@abdulfatir
Copy link

@normster @tridao @albertfgu I believe this feature would be very nice to have in a stable release. Can we work towards merging this into main and have it in the next stable release? I am happy to help in any way.

@@ -301,11 +305,11 @@ def mamba_inner_fn(
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
C_proj_bias=None, delta_softplus=True
C_proj_bias=None, mask=None, delta_softplus=True
):
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
Copy link
Contributor

@season0528 season0528 Feb 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some confusion about this line of code. Since MambaInnerFn doesn't provides parameters option for mask, it seems that it has no effect one the fwd and bwd pass.
Hence, how can the mask be applied to mark the sequence boundaries?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be wrong, but the PR modified Line 162 in the very same file, which is the definition of MambaInnerFn.forward method. Also, multiplications on Line 181 and 220 used mask.

Copy link
Contributor

@season0528 season0528 Feb 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! You are right. I might missed this line of code on L162.
But I see in this PR, attention_mask is only used in the forward pass, and seem not to be used in the backward pass. So when I tried to feed batch data with left padding and masking (batch_size, seq_len, hidden_dim) into mamba block , it reported error. Has anyone encountered a similar error?

Err Msg:

  File "/blahblah/miniconda3/envs/dev/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/blahblah/miniconda3/envs/dev/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: function MambaInnerFnBackward returned an incorrect number of gradients (expected 16, got 15)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xtwigs pointed out that one can add None in the returned gradients to fix the issue. (For fellows above who wonder where to put None, I did it at the end of the tuple.) I agree with @xtwigs as we don't calculate gradient on the mask tensor.

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.

On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).

Can someone verify if my thought process is accurate?

Copy link
Contributor

@season0528 season0528 Feb 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xtwigs pointed out that one can add None in the returned gradients to fix the issue. (For fellows above who wonder where to put None, I did it at the end of the tuple.) I agree with @xtwigs as we don't calculate gradient on the mask tensor.

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.
On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).
Can someone verify if my thought process is accurate?

I have tried this approach but encountered with CUDA OOM, even with much more GPUs and much smaller seq_len. (8x nodes, 64x A100 GPUs, and seq_len=512 for 1.4B mamba model)

@season0528
Copy link
Contributor

season0528 commented Feb 29, 2024

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.
On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).
Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also meet this error. Could you please provide some insight about how to fix this issue on the backward pass?

@season0528
Copy link
Contributor

season0528 commented Feb 29, 2024

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.

On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).

Can someone verify if my thought process is accurate?

I have verified your idea and it seems not to work and causes CUDA OOM.
Err Msg:

  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/some_path/zigzagcai/devel/mamba/mamba_ssm/modules/mamba_simple.py", line 146, in forward
    out = mamba_inner_fn(
  File "/some_path/zigzagcai/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 310, in mamba_inner_fn
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 97, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/some_path/zigzagcai/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 221, in forward
    out, scan_intermediates, out_z = selective_scan_cuda.fwd(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 320.00 MiB (GPU 0; 79.33 GiB total capacity; 74.83 GiB already allocated; 313.81 MiB free; 77.13 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

My code changes:

diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 35143ad..6792749 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -159,7 +159,7 @@ class MambaInnerFn(torch.autograd.Function):
     def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                 out_proj_weight, out_proj_bias,
                 A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1):
+                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=0):
         """
              xz: (batch, dim, seqlen)
         """
@@ -298,7 +298,7 @@ class MambaInnerFn(torch.autograd.Function):
                 dout_proj_weight, dout_proj_bias,
                 dA, dB, dC, dD,
                 ddelta_bias if delta_bias is not None else None,
-                dB_proj_bias, dC_proj_bias, None)
+                dB_proj_bias, dC_proj_bias, None, None)

@enneamer
Copy link

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.
On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).
Can someone verify if my thought process is accurate?

I have verified your idea and it seems not to work and causes CUDA OOM. Err Msg:

  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/hwfile/caizheng/devel/mamba/mamba_ssm/modules/mamba_simple.py", line 146, in forward
    out = mamba_inner_fn(
  File "/mnt/hwfile/caizheng/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 310, in mamba_inner_fn
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 97, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/mnt/hwfile/caizheng/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 221, in forward
    out, scan_intermediates, out_z = selective_scan_cuda.fwd(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 320.00 MiB (GPU 0; 79.33 GiB total capacity; 74.83 GiB already allocated; 313.81 MiB free; 77.13 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

My code changes:

diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 35143ad..6792749 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -159,7 +159,7 @@ class MambaInnerFn(torch.autograd.Function):
     def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                 out_proj_weight, out_proj_bias,
                 A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1):
+                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=0):
         """
              xz: (batch, dim, seqlen)
         """
@@ -298,7 +298,7 @@ class MambaInnerFn(torch.autograd.Function):
                 dout_proj_weight, dout_proj_bias,
                 dA, dB, dC, dD,
                 ddelta_bias if delta_bias is not None else None,
-                dB_proj_bias, dC_proj_bias, None)
+                dB_proj_bias, dC_proj_bias, None, None)

Is it using the default configuration defined in MambaConfig? The default values specify a huge network with 64 layers and embedding size of 2560. And checkpoint_lvl=0 disables checkpoint and then asks the forward pass to keep convolution and delta results in the GPU memory.

@season0528
Copy link
Contributor

season0528 commented Feb 29, 2024

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.
On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).
Can someone verify if my thought process is accurate?

I have verified your idea and it seems not to work and causes CUDA OOM. Err Msg:

  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/some_path/zigzagcai/devel/mamba/mamba_ssm/modules/mamba_simple.py", line 146, in forward
    out = mamba_inner_fn(
  File "/some_path/zizgagcai/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 310, in mamba_inner_fn
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 97, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/some_path/zigzagcai/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 221, in forward
    out, scan_intermediates, out_z = selective_scan_cuda.fwd(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 320.00 MiB (GPU 0; 79.33 GiB total capacity; 74.83 GiB already allocated; 313.81 MiB free; 77.13 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

My code changes:

diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 35143ad..6792749 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -159,7 +159,7 @@ class MambaInnerFn(torch.autograd.Function):
     def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                 out_proj_weight, out_proj_bias,
                 A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1):
+                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=0):
         """
              xz: (batch, dim, seqlen)
         """
@@ -298,7 +298,7 @@ class MambaInnerFn(torch.autograd.Function):
                 dout_proj_weight, dout_proj_bias,
                 dA, dB, dC, dD,
                 ddelta_bias if delta_bias is not None else None,
-                dB_proj_bias, dC_proj_bias, None)
+                dB_proj_bias, dC_proj_bias, None, None)

Is it using the default configuration defined in MambaConfig? The default values specify a huge network with 64 layers and embedding size of 2560. And checkpoint_lvl=0 disables checkpoint and then asks the forward pass to keep convolution and delta results in the GPU memory.

No. I am evaluating mamba model with 1.4B parameter size, where layers = 48 and model dimension = 2048 are equivalent to the size on the repo page.
As we know from the mamba paper, when I set checkpoint_lvl=0, it will disable the recomputation of conv1d_out, delta in backward pass and store those values in GPU memory, which leads to much more memory usage.

Here is my experiment details:

  1. original 1.4B mamba model: works well with 1x node with 8x A100 GPUs
  2. 1.4B mamba model patched with this PR: encounters OOM even with 4x nodes with 32x A100 GPUs

@season0528
Copy link
Contributor

We are trying this PR because we want mamba to process packed sequence like what has been done in transformer-based models.
If we directly pad the sequence with zero, then a lot of computation will be wasted on meaningless padded tokens.
We just want to use mask to mark the meaningless padded tokens and let the computation focus on regular tokens.

@xtwigs
Copy link

xtwigs commented Feb 29, 2024

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.
On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).
Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also meet this error. Could you please provide some insight about how to fix this issue on the backward pass?

Hi! Sorry, but I haven't run into the OOM issue. I keep a fork of this code here in case it might help. (I also run this with checkpoint_lvl=0)

@season0528
Copy link
Contributor

season0528 commented Mar 4, 2024

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.
On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).
Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also meet this error. Could you please provide some insight about how to fix this issue on the backward pass?

Hi! Sorry, but I haven't run into the OOM issue. I keep a fork of this code here in case it might help. (I also run this with checkpoint_lvl=0)

Hello, thanks for the sharing! :D I might be wrong. But when I tried this branch and found that OOM still appeared in the backward pass, meanwhile python tests/test_padding.py leads to huge max L2 errors (padded): tensor(3442.7573) in the forward pass, which did not appear in Norman’s original PR and might indicated the left padding and masking not work in the forward pass.

Is there any reproducible test code snippet indicating left padding+ masking works in both fwd and bwd pass?
Thanks!

@xtwigs
Copy link

xtwigs commented Mar 4, 2024

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.
On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).
Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also meet this error. Could you please provide some insight about how to fix this issue on the backward pass?

Hi! Sorry, but I haven't run into the OOM issue. I keep a fork of this code here in case it might help. (I also run this with checkpoint_lvl=0)

Hello, thanks for the sharing! :D I might be wrong. But when I tried this branch and found that OOM still appeared in the backward pass, meanwhile python tests/test_padding.py leads to huge max L2 errors (padded): tensor(3442.7573) in the forward pass, which did not appear in Norman’s original PR and might indicated the left padding and masking not work in the forward pass.

Is there any reproducible test code snippet indicating left padding+ masking works in both fwd and bwd pass? Thanks!

Can you try this while disabling the dropout module added in the mamba simple code? (default was set to 0.1)

@season0528
Copy link
Contributor

season0528 commented Mar 11, 2024

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.
On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).
Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also meet this error. Could you please provide some insight about how to fix this issue on the backward pass?

Hi! Sorry, but I haven't run into the OOM issue. I keep a fork of this code here in case it might help. (I also run this with checkpoint_lvl=0)

Hello, thanks for the sharing! :D I might be wrong. But when I tried this branch and found that OOM still appeared in the backward pass, meanwhile python tests/test_padding.py leads to huge max L2 errors (padded): tensor(3442.7573) in the forward pass, which did not appear in Norman’s original PR and might indicated the left padding and masking not work in the forward pass.
Is there any reproducible test code snippet indicating left padding+ masking works in both fwd and bwd pass? Thanks!

Can you try this while disabling the dropout module added in the mamba simple code? (default was set to 0.1)

Hi xtwigs, many thanks for your reply!
I tried the latest code of your branch and found mamba block runnable without error, but when I run the test code, the max L2 errors are still relatively large.

@laulampaul
Copy link

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.
On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).
Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also meet this error. Could you please provide some insight about how to fix this issue on the backward pass?

Hi! Sorry, but I haven't run into the OOM issue. I keep a fork of this code here in case it might help. (I also run this with checkpoint_lvl=0)

Hello, thanks for the sharing! :D I might be wrong. But when I tried this branch and found that OOM still appeared in the backward pass, meanwhile python tests/test_padding.py leads to huge max L2 errors (padded): tensor(3442.7573) in the forward pass, which did not appear in Norman’s original PR and might indicated the left padding and masking not work in the forward pass.
Is there any reproducible test code snippet indicating left padding+ masking works in both fwd and bwd pass? Thanks!

Can you try this while disabling the dropout module added in the mamba simple code? (default was set to 0.1)

Hi xtwigs, many thanks for your reply! I tried the latest code of your branch and found mamba block runnable without error, but when I run the test code, the max L2 errors are still relatively large.

Hi zigzagcai , have you solved the problem of masking in the mamba block ?

@albertfgu albertfgu force-pushed the main branch 2 times, most recently from 6d45666 to 41d30ce Compare June 3, 2024 12:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.