
[Feature] Support variable-length sequences for mamba block #244

Status: Open (37 commits to be merged into base: main)

zigzagcai
Contributor

@zigzagcai zigzagcai commented Mar 14, 2024

Support variable-length sequences for the mamba block via cu_seqlens in the forward and backward passes, similar to what flash attention does in its varlen_fwd/varlen_bwd API (cumulative sequence lengths cu_seqlens, or equivalently a lower-triangular block-diagonal attention mask).

We have tested that training with variable-length sequences on real-world datasets can bring a 2~4x speedup.

  • Why do we need it?
    It gives high speedup and hardware utilization on the real-world datasets we tested. It improves hardware utilization when you have variable-length sequences and do not want to waste computing resources on meaningless padded tokens. It is especially useful for mamba training on real-world datasets, where the length distribution varies widely and a large proportion of samples are short sequences. Last but not least, we ensure exact fwd/bwd numerical equality with the padding approach.

  • How to use it?
    Zero learning overhead: the packed mamba API is similar to the packed flash-attn API or the packed mamba2 API. You just need to pack multiple variable-length sequences into one and additionally pass cu_seqlens into the mamba forward pass.
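
As a minimal sketch of the packing step described above, written in plain Python for illustration: `pack_sequences` is a hypothetical helper, not part of this PR. It only shows that cu_seqlens is the running total of the sequence lengths with a leading 0.

```python
from itertools import accumulate

def pack_sequences(seqs):
    """Concatenate variable-length sequences and return (packed, cu_seqlens).

    Hypothetical helper for illustration only; it is not part of the PR.
    """
    packed = [tok for seq in seqs for tok in seq]
    # Cumulative sequence lengths, starting at 0.
    cu_seqlens = [0] + list(accumulate(len(seq) for seq in seqs))
    return packed, cu_seqlens

seqs = [[11, 12, 13], [21], [31, 32]]
packed, cu_seqlens = pack_sequences(seqs)
print(packed)      # [11, 12, 13, 21, 31, 32]
print(cu_seqlens)  # [0, 3, 4, 6]
```

Sequence i then occupies packed[cu_seqlens[i]:cu_seqlens[i + 1]]. With real tensors, the same running total can be obtained with torch.cumsum.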

Note:
We thank @wang-zerui for the Python reference implementation and the invaluable discussion on how to ensure numerical equality.
This is joint work with @wang-zerui, @Dmovic, and @ptxu78.

Some related issues about mamba and flash-attn variable-length training:

  1. Variable input sequence length #236
  2. Question about Using Mamba/Mamba2 with Variable Input Lengths #356
  3. Question about does mamba support variable-length input or cu_seqlens like flash attention? #180
  4. no limitation on the inputs? #246 (comment)
  5. How did flash-attn compute attention for cu_seqlens Dao-AILab/flash-attention#850 (comment)
  6. Will attention_mask be extended to 3D? (concatenate short samples for efficient training) Dao-AILab/flash-attention#432 (comment)

@zigzagcai
Contributor Author

zigzagcai commented Mar 14, 2024

Hello @tridao @albertfgu

Thanks for the awesome work on mamba; it is really a strong competitor to the transformer!

We have noticed that some issues (#236, #180) state a need for training on variable-length sequences, but the users can't find functionality such as attention_mask or cu_seqlens in the mamba block, which is commonly used in transformer architectures to support variable-length training.

Also, in real-world scenarios the length distribution of datasets varies widely, and simply padding every sample to the maximum length wastes computing resources on meaningless padded tokens.

So we implemented this PR and hope it helps!

@zigzagcai zigzagcai force-pushed the feat/add-cu_seqlens branch from 77e58cb to a78a9eb Compare March 15, 2024 09:31
@zigzagcai zigzagcai force-pushed the feat/add-cu_seqlens branch from aea08ca to 842bef5 Compare March 18, 2024 09:19
@EricPaul03

Hello, it's great to see your input on variable length data. How can I use the method you provided? Is there any difference in results between it and padding?

@zigzagcai
Contributor Author

zigzagcai commented Mar 18, 2024

Thank you for your interest in this PR!
For the forward pass of the mamba block, we compared the results numerically against the padding approach and found them consistent. (Numerical equality for the forward pass has been verified.)
For the backward pass, we plan to add unit tests to show consistency when we have bandwidth. (Numerical equality for the backward pass has not been verified yet.)

Update (2024/03/19):
Numerical equality for both the forward and backward passes has been validated.
In terms of training loss and accuracy, this PR is numerically aligned with the padding approach while avoiding wasted computation on meaningless padded tokens.
When training on a sample dataset, variable-length training brings a high speedup compared to padding.

@EricPaul03

Thank you for your reply. Due to performance considerations, I would like to use bidirectional mamba. Should I wait for your updated code?

@zigzagcai
Contributor Author

zigzagcai commented Mar 19, 2024

Hi @EricPaul03 ,

@Dmovic has created a unit test for the backward pass of the mamba block with variable-length sequences, and the test results show numerical equality for both the forward and backward passes with varlen inputs.

I haven't tried it with bidirectional mamba. But since it is numerically equivalent for the default unidirectional mamba, I think you can just give it a try!

@zigzagcai
Contributor Author

zigzagcai commented Mar 19, 2024

image

To give a simple example: what we originally pass into the mamba block is an input of shape (batch_size=7, seq_len=10, hidden_dim).
With this PR, we can instead pass the variable-length mamba block an input of shape (batch_size=1, seq_len=32, hidden_dim), where the original variable-length sequences are packed into one sequence, with an additional parameter cu_seqlens marking the sequence boundaries.

From the above figure, we can clearly see that with this PR the mamba block spends its computing resources only on the real tokens of the variable-length sequences and avoids the overhead of meaningless padding tokens.
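
To put numbers on the shapes in this example, here is a small sketch. The individual sequence lengths are an assumption: the comment only gives batch_size=7, a padded length of 10, and 32 packed tokens. The lengths used below are consistent with those figures and with the cu_seqlens example that appears later in this thread.

```python
def padding_stats(seq_lens):
    """Return (padded_tokens, real_tokens, wasted_fraction) for a padded batch."""
    padded_tokens = len(seq_lens) * max(seq_lens)  # every sample padded to the longest
    real_tokens = sum(seq_lens)                    # what the packed sequence contains
    return padded_tokens, real_tokens, 1 - real_tokens / padded_tokens

# Assumed lengths: 7 sequences, longest is 10, 32 tokens in total.
seq_lens = [5, 10, 3, 1, 2, 5, 6]
padded, real, wasted = padding_stats(seq_lens)
print(padded, real, f"{wasted:.0%}")  # 70 32 54%
```

Under these assumed lengths, more than half of the positions in the padded batch are padding that the packed layout does not have to process.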

Variable-length training is very useful for improving hardware utilization during training, and the well-known flash attention already supports it via cu_seqlens.
Therefore, we believe that mamba, as a competitor to the transformer, can improve its hardware utilization when training on real-world datasets (where the length distribution varies widely between samples) through this PR!

@EricPaul03

Thank you for your answer. This is great code, and I will try to use it in my project!

@EricPaul03

EricPaul03 commented Mar 19, 2024

Sorry to bother you again. I would like to implement the same operation for bidirectional mamba. Do I also need to change the values of cu_seqlens when flipping the sequence for the reverse pass, and can the two directions share d_conv?
For example:
out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(
    conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seqlens
)

out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(
    conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]),
    delta_bias, delta_softplus, cu_seqlens
)  # should cu_seqlens be changed?

ctx.save_for_backward(
    xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
    conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b,
    out_f, out_b, cu_seqlens, d_conv
)  # the same d_conv?

@EricPaul03

I think I should divide conv1d_out, delta, etc. into subsequences and reverse each subsequence separately? (Instead of the entire sequence, use the same cu_seqlens?)

@junphine

junphine commented Mar 21, 2024

I copied some methods into MixerModel to help use this feature.

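import torch
import torch.nn.functional as F
from einops import rearrange

def unpad_input(self, hidden_states, attention_mask):
    # Flatten (batch, seqlen, ...) into (batch * seqlen, ...).
    hidden_states = rearrange(hidden_states, "b s ... -> (b s) ...")
    # The two squeeze(1) calls expect a (batch, 1, 1, seqlen) mask; valid tokens are assumed to be marked with 1.
    valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)
    seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
    # Flat positions of the valid tokens, needed later to scatter the outputs back.
    indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths with a leading 0.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # Keep only valid tokens and add a batch dimension of 1: (1, total_tokens, ...).
    hidden_states = hidden_states[indices].unsqueeze(0)
    return hidden_states, indices, cu_seqlens, max_seqlen_in_batch

def pad_input(self, hidden_states, indices, batch, seqlen):
    """
    :param hidden_states: packed tensor of shape [L, H], not [B, L, H]
    :param indices: the indices returned by unpad_input
    :param batch: batch size of the padded output
    :param seqlen: max_seqlen_in_batch returned by unpad_input
    :return: padded tensor of shape [batch, seqlen, H]
    """
    output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device, dtype=hidden_states.dtype)
    # Scatter the packed tokens back to their original flat positions.
    output[indices] = hidden_states
    return rearrange(output, "(b s) ... -> b s ...", b=batch)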
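
To make the round trip concrete, here is a plain-Python sketch of what the two helpers above do, using nested lists instead of tensors. The functions `unpad` and `pad` are hypothetical stand-ins written for illustration, not the methods posted above.

```python
from itertools import accumulate

def unpad(batch, mask):
    """Drop padded positions; return packed tokens, flat indices and cu_seqlens."""
    seqlen = len(batch[0])
    # Flat (row-major) positions of the valid tokens.
    indices = [b * seqlen + s
               for b, row in enumerate(mask)
               for s, keep in enumerate(row) if keep]
    flat = [tok for row in batch for tok in row]
    packed = [flat[i] for i in indices]
    cu_seqlens = [0] + list(accumulate(sum(row) for row in mask))
    return packed, indices, cu_seqlens

def pad(packed, indices, batch_size, seqlen, fill=0):
    """Scatter packed tokens back into a (batch_size, seqlen) padded layout."""
    flat = [fill] * (batch_size * seqlen)
    for i, tok in zip(indices, packed):
        flat[i] = tok
    return [flat[b * seqlen:(b + 1) * seqlen] for b in range(batch_size)]

batch = [[1, 2, 3], [4, 0, 0], [5, 6, 0]]
mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]
packed, indices, cu_seqlens = unpad(batch, mask)
restored = pad(packed, indices, batch_size=3, seqlen=3)
print(packed)      # [1, 2, 3, 4, 5, 6]
print(cu_seqlens)  # [0, 3, 4, 6]
print(restored == batch)  # True
```

One thing this sketch makes visible: the indices are computed against the padded length of the input, so the seqlen passed when padding back has to be that same padded length. It coincides with the longest real sequence only when at least one sample fills the padded length, as in this example.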

@zigzagcai
Contributor Author

zigzagcai commented Mar 21, 2024

For bidirectional mamba, you need to pass reverse_cu_seqlens to the reverse pass, like this:

out_rev = self.mamba_rev(
    hidden_states.flip(dims=(1,)),  # Flip along the sequence length dimension
    cu_seqlens=reverse_cu_seqlens,  # Reversed cu_seqlens
    inference_params=inference_params,
).flip(dims=(1,))  # Flip back for combining with forward hidden states

For example, if you have cu_seqlens = torch.tensor([0, 5, 15, 18, 19, 21, 26, 32]), then reverse_cu_seqlens should be tensor([0, 6, 11, 13, 14, 17, 27, 32]), which marks the positions in the reverse pass where hidden_states need to be reset.

We can calculate reverse_cu_seqlens with the following formula:

reverse_cu_seqlens = torch.cumsum(torch.cat((torch.tensor([0]), (cu_seqlens[1:]-cu_seqlens[:-1]).flip(dims=(0,))), dim=0), dim=0)
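
The formula can be checked by hand with a plain-Python equivalent. This is only a sketch for illustration (the function name mirrors the variable above; it is not part of the PR): recover the per-sequence lengths, reverse their order, and take the running total again.

```python
from itertools import accumulate

def reverse_cu_seqlens(cu_seqlens):
    """Boundaries of the packed sequence after flipping it along the length axis."""
    # Per-sequence lengths are the differences between consecutive boundaries.
    seq_lens = [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]
    # After the flip the sequences appear in reverse order.
    return [0] + list(accumulate(reversed(seq_lens)))

cu_seqlens = [0, 5, 15, 18, 19, 21, 26, 32]
print(reverse_cu_seqlens(cu_seqlens))  # [0, 6, 11, 13, 14, 17, 27, 32]
```

Applying the function twice returns the original cu_seqlens, which is a quick sanity check for any implementation of it.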

@zigzagcai
Contributor Author

zigzagcai commented Mar 21, 2024

I think you might not need to divide these items into subsequences. All you need is to pass reverse_cu_seqlens to the reverse pass, and you get the benefits of both bidirectional and variable-length training.

To show how bidirectional mamba combines with this PR's variable-length sequences, I drew my understanding here:
image

The mechanism can be viewed simply: when scanning bidirectionally, hidden_states need to be reset at the sequence boundaries in both directions.
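
The claim that flipping the whole packed sequence with reversed boundaries is equivalent to reversing each subsequence separately can be checked on a toy recurrence. The sketch below uses h_t = decay * h_{t-1} + x_t as a stand-in for the selective scan; it is an illustrative assumption, not the CUDA kernel, and the helper names are hypothetical.

```python
def scan_with_resets(x, cu_seqlens, decay=0.5):
    """Toy recurrence h_t = decay * h_{t-1} + x_t, with h reset at each sequence start."""
    starts = set(cu_seqlens[:-1])
    out, h = [], 0.0
    for t, x_t in enumerate(x):
        if t in starts:
            h = 0.0  # do not leak state across a sequence boundary
        h = decay * h + x_t
        out.append(h)
    return out

def reverse_cu_seqlens(cu_seqlens):
    """Sequence boundaries after flipping the packed sequence."""
    lens = [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]
    rev = [0]
    for n in reversed(lens):
        rev.append(rev[-1] + n)
    return rev

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
cu_seqlens = [0, 2, 6]  # two packed sequences: x[0:2] and x[2:6]

# Approach 1: flip the whole packed sequence, scan with reversed boundaries, flip back.
flipped = scan_with_resets(x[::-1], reverse_cu_seqlens(cu_seqlens))[::-1]

# Approach 2: reverse each subsequence separately, scan it, and flip it back.
per_seq = []
for a, b in zip(cu_seqlens, cu_seqlens[1:]):
    per_seq += scan_with_resets(x[a:b][::-1], [0, b - a])[::-1]

assert flipped == per_seq
print(flipped)  # [2.0, 2.0, 7.0, 8.0, 8.0, 6.0]
```

For this toy recurrence both approaches give the same output, so splitting into subsequences is unnecessary as long as the reverse pass resets its state at the reversed boundaries.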

@zigzagcai zigzagcai force-pushed the feat/add-cu_seqlens branch 2 times, most recently from c1b68de to 5b024c5 Compare July 22, 2024 08:50
@zigzagcai zigzagcai force-pushed the feat/add-cu_seqlens branch from 8f472c3 to b69b957 Compare July 22, 2024 08:57
@zigzagcai
Contributor Author

zigzagcai commented Jul 22, 2024

Update (2024/07/22):

  • I have migrated to tridao's latest implementation of variable-length causal_conv1d (which requires causal-conv1d>=1.4.0) in this commit. All variable-length features in mamba are now powered natively by CUDA kernels, which is much faster.

  • The API is now unified with mamba2 and flash-attn (mamba, mamba2, and flash-attn all use cu_seqlens for variable-length training), which makes it much easier to use.

  • The unit test shows that the variable-length mamba block agrees with the padding approach to within floating-point tolerance in both the forward and backward passes.

python tests/ops/test_mamba_cu_seqlens_equivalence.py

Generate random cu_seqlens = [0, 116, 155, 349, 479, 674, 864, 881, 1024]
max diff for output in varlen_mamba fwd pass: 4.470348358154297e-08
mean diff for output in varlen_mamba fwd pass: 5.5261386577853955e-09
max diff for A_log in varlen_mamba bwd pass: 6.239861249923706e-08
mean diff for A_log in varlen_mamba bwd pass: 5.321690865756068e-10
max diff for D in varlen_mamba bwd pass: 6.318092346191406e-06
mean diff for D in varlen_mamba bwd pass: 6.176169335958548e-07
max diff for in_proj.weight in varlen_mamba bwd pass: 1.9073486328125e-05
mean diff for in_proj.weight in varlen_mamba bwd pass: 1.098805341825937e-06
max diff for conv1d.weight in varlen_mamba bwd pass: 5.662441253662109e-06
mean diff for conv1d.weight in varlen_mamba bwd pass: 8.699786349097849e-07
max diff for conv1d.bias in varlen_mamba bwd pass: 1.0013580322265625e-05
mean diff for conv1d.bias in varlen_mamba bwd pass: 1.4602501323679462e-06
max diff for x_proj.weight in varlen_mamba bwd pass: 3.6954879760742188e-06
mean diff for x_proj.weight in varlen_mamba bwd pass: 2.984411295869904e-08
max diff for dt_proj.weight in varlen_mamba bwd pass: 8.731149137020111e-09
mean diff for dt_proj.weight in varlen_mamba bwd pass: 3.4094516099258954e-10
max diff for dt_proj.bias in varlen_mamba bwd pass: 2.60770320892334e-08
mean diff for dt_proj.bias in varlen_mamba bwd pass: 2.458180992093162e-09
max diff for out_proj.weight in varlen_mamba bwd pass: 5.7220458984375e-06
mean diff for out_proj.weight in varlen_mamba bwd pass: 2.629302855439164e-07

pytest tests/

============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.3.1, pluggy-1.5.0
rootdir: /dev/varlen_mamba
plugins: typeguard-3.0.2
collected 392 items

tests/ops/test_selective_scan.py ....................                    [  5%]
tests/ops/test_selective_scan_var_len.py ............                    [  8%]
tests/ops/triton/test_layernorm_gated.py .s.s.s.s.s.s.....s.s.s.s.s.s... [ 16%]
..s.s.s.s.s.s.....s.s.s.s.s.s....                                        [ 24%]
tests/ops/triton/test_selective_state_update.py ........................ [ 30%]
........................................................................ [ 48%]
........................................................................ [ 67%]
........................................................................ [ 85%]
..............................                                           [ 93%]
tests/ops/triton/test_ssd.py ........................                    [ 99%]
tests/test_generation.py ..                                              [100%]

================= 368 passed, 24 skipped in 256.74s (0:04:16) ==================

@zigzagcai
Contributor Author

zigzagcai commented Jul 22, 2024

Dear authors,

@tridao @albertfgu Firstly, thanks for the awesome work on the theoretical analysis and code development of mamba, mamba2, and the other state space models in the series!

Currently, many users (#356, #236, #180) expect mamba to natively support variable-length training (as flash-attn and mamba2 already do) for better hardware efficiency, so we developed this feature.

In this PR:
(1) We provide an API unified with mamba2 and flash-attn to support variable-length training (via cu_seqlens).
(2) Variable-length mamba is natively powered by the causal_conv1d and selective scan CUDA kernels.

So, could this PR be reviewed and merged as a feature for mamba, if possible? Thanks!

@zigzagcai zigzagcai changed the title Support variable-length sequences for mamba block [Feature] Support variable-length sequences for mamba block Jul 22, 2024
@zigzagcai
Contributor Author

zigzagcai commented Aug 9, 2024

It's great to see that there is already one paper/project (Is Mamba Compatible with Trajectory Optimization in Offline Reinforcement Learning, NeurIPS'24) adopting our code in the area of offline reinforcement learning.

Link:
https://arxiv.org/pdf/2405.12094

@JindongJiang

Hi @zigzagcai, thank you for the great work! I tried to install your version but encountered the selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol problem. Does it also occur for you when you test the code?

The full pipeline I did is the following:

# (optionally) clone causal-conv1d, also tried pip install causal-conv1d==1.4.0
git clone https://github.com/Dao-AILab/causal-conv1d
cd causal-conv1d
git checkout v1.4.0
pip install -e .

cd ..
# clone and checkout your pr
git clone https://github.com/state-spaces/mamba
cd mamba
git fetch origin pull/244/head:pr-244
git checkout pr-244
pip install -e .

Tried installing with pytorch 2.4, 2.1, cuda 12.5, 12.1. All settings have the same problem:

> python tests/ops/test_mamba_cu_seqlens_equivalence.py

Traceback (most recent call last):
  File "/.../mamba/tests/ops/test_mamba_cu_seqlens_equivalence.py", line 5, in <module>
    from mamba_ssm.modules.mamba_simple import Mamba
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/__init__.py", line 3, in <module>
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/selective_scan_interface.py", line 16, in <module>
    import selective_scan_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops10zeros_like4callERKNS_6TensorESt8optionalIN3c1010ScalarTypeEES5_INS6_6LayoutEES5_INS6_6DeviceEES5_IbES5_INS6_12MemoryFormatEE

Additionally, I found that the installed causal-conv1d and mamba-ssm don't seem to recognize each other: when I run the following, it shows that causal-conv1d is required by nothing:

>pip show causal-conv1d

Name: causal-conv1d
Version: 1.4.0
Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
Home-page: https://github.com/Dao-AILab/causal-conv1d
Author: Tri Dao
Author-email: [email protected]
License:
Location: /usr/local/lib/python3.10/dist-packages
Requires: ninja, packaging, torch
Required-by: (empty here)

Similarly, mamba_ssm does not require causal-conv1d:

> pip show mamba-ssm

Name: mamba_ssm
Version: 2.2.2
Summary: Mamba state-space model
Home-page:
Author:
Author-email: Tri Dao <[email protected]>, Albert Gu <[email protected]>
...
Location: /usr/local/lib/python3.10/dist-packages
Requires: einops, ninja, packaging, setuptools, torch, transformers, triton (causal-conv1d is not here)
Required-by:

If this issue doesn't occur for you, could you provide the install script you are using for the most up-to-date version? Thanks!

@zigzagcai
Contributor Author

zigzagcai commented Aug 11, 2024

Hi @JindongJiang ,

Here are my minimal steps to reproduce.

  • The hardware and software info:
HW: A800/A100
Driver: CUDA 11.8
  • Steps to set up the environment:
conda create -n mamba_dev python=3.10
conda activate mamba_dev
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install causal-conv1d==1.4.0
pip install einops huggingface-hub transformers triton pytest 
git clone https://github.com/zigzagcai/varlen_mamba.git --branch feat/add-cu_seqlens
cd varlen_mamba
pip install -e .
  • Run tests:
pytest tests/

============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.3.2, pluggy-1.5.0
rootdir: /blahblah/zigzagcai/varlen_mamba
plugins: typeguard-3.0.2
collected 392 items

tests/ops/test_selective_scan.py ....................                    [  5%]
tests/ops/test_selective_scan_var_len.py ............                    [  8%]
tests/ops/triton/test_layernorm_gated.py .s.s.s.s.s.s.....s.s.s.s.s.s... [ 16%]
..s.s.s.s.s.s.....s.s.s.s.s.s....                                        [ 24%]
tests/ops/triton/test_selective_state_update.py ........................ [ 30%]
........................................................................ [ 48%]
........................................................................ [ 67%]
........................................................................ [ 85%]
..............................                                           [ 93%]
tests/ops/triton/test_ssd.py ........................                    [ 99%]
tests/test_generation.py ..                                              [100%]

================= 368 passed, 24 skipped in 183.78s (0:03:03) ==================

python tests/ops/test_mamba_cu_seqlens_equivalence.py

Generate random cu_seqlens = [0, 5, 84, 182, 202, 284, 796, 836, 1024]
max diff for output in varlen_mamba fwd pass: 6.407499313354492e-07
mean diff for output in varlen_mamba fwd pass: 3.794611203034037e-08
max diff for A_log in varlen_mamba bwd pass: 6.705522537231445e-08
mean diff for A_log in varlen_mamba bwd pass: 6.687657094772703e-10
max diff for D in varlen_mamba bwd pass: 4.76837158203125e-06
mean diff for D in varlen_mamba bwd pass: 6.003104999763309e-07
max diff for in_proj.weight in varlen_mamba bwd pass: 1.9073486328125e-05
mean diff for in_proj.weight in varlen_mamba bwd pass: 1.0953947366942884e-06
max diff for conv1d.weight in varlen_mamba bwd pass: 5.364418029785156e-06
mean diff for conv1d.weight in varlen_mamba bwd pass: 8.792806056590052e-07
max diff for conv1d.bias in varlen_mamba bwd pass: 7.867813110351562e-06
mean diff for conv1d.bias in varlen_mamba bwd pass: 1.4787228792556562e-06
max diff for x_proj.weight in varlen_mamba bwd pass: 5.029141902923584e-06
mean diff for x_proj.weight in varlen_mamba bwd pass: 3.1919995535645285e-08
max diff for dt_proj.weight in varlen_mamba bwd pass: 1.3300450518727303e-08
mean diff for dt_proj.weight in varlen_mamba bwd pass: 3.616623112101536e-10
max diff for dt_proj.bias in varlen_mamba bwd pass: 3.166496753692627e-08
mean diff for dt_proj.bias in varlen_mamba bwd pass: 2.6783406603669846e-09
max diff for out_proj.weight in varlen_mamba bwd pass: 6.67572021484375e-06
mean diff for out_proj.weight in varlen_mamba bwd pass: 2.693569740586099e-07

@zigzagcai
Contributor Author

zigzagcai commented Aug 11, 2024

Hi, @JindongJiang

Firstly, thanks for your interest in this PR!

  1. I have also hit the selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol error when building mamba from source.
    I searched the mamba issues and it seems to be a known issue: Import Error #169 (comment)
    I think it might be caused by stale cached files in conda. I always re-create a fresh conda env from scratch to work around it. You might also try the approach from the above comment (uninstall and reinstall with --no-cache-dir).

  2. It might be a problem with the pyproject.toml build system. I noticed that this file was added just recently in commit 323db26.
    You can delete pyproject.toml and manually install the dependencies to work around it.

pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install causal-conv1d==1.4.0
pip install einops huggingface-hub transformers triton pytest 

@JindongJiang

JindongJiang commented Aug 11, 2024

Hi @zigzagcai, thank you very much for the help. Interestingly, deleting pyproject.toml solves the undefined symbol issue. Now I can successfully run python tests/ops/test_mamba_cu_seqlens_equivalence.py with CUDA 12.5 and PyTorch 2.4.0. However, the reported diffs seem quite large for in_proj.weight and out_proj.weight. Rerunning a few more times eventually triggers the assert torch.allclose(mamba_grad[name], mamba_ref_grad[name], rtol=rtol, atol=atol) AssertionError.

> python tests/ops/test_mamba_cu_seqlens_equivalence.py
/lustre/fs2/portfolios/nvr/users/jindongj/Documents/Programming/PyTorch/tmp/mamba_varlen/varlen_mamba/mamba_ssm/ops/selective_scan_interface.py:169: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
/lustre/fs2/portfolios/nvr/users/jindongj/Documents/Programming/PyTorch/tmp/mamba_varlen/varlen_mamba/mamba_ssm/ops/selective_scan_interface.py:265: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout):
/lustre/fs2/portfolios/nvr/users/jindongj/Documents/Programming/PyTorch/tmp/mamba_varlen/varlen_mamba/mamba_ssm/ops/triton/layer_norm.py:986: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(
/lustre/fs2/portfolios/nvr/users/jindongj/Documents/Programming/PyTorch/tmp/mamba_varlen/varlen_mamba/mamba_ssm/ops/triton/layer_norm.py:1045: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout, *args):
/lustre/fs2/portfolios/nvr/users/jindongj/Documents/Programming/PyTorch/tmp/mamba_varlen/varlen_mamba/mamba_ssm/distributed/tensor_parallel.py:26: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
/lustre/fs2/portfolios/nvr/users/jindongj/Documents/Programming/PyTorch/tmp/mamba_varlen/varlen_mamba/mamba_ssm/distributed/tensor_parallel.py:62: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, grad_output):
/lustre/fs2/portfolios/nvr/users/jindongj/Documents/Programming/PyTorch/tmp/mamba_varlen/varlen_mamba/mamba_ssm/ops/triton/ssd_combined.py:758: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
/lustre/fs2/portfolios/nvr/users/jindongj/Documents/Programming/PyTorch/tmp/mamba_varlen/varlen_mamba/mamba_ssm/ops/triton/ssd_combined.py:836: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout, *args):
Generate random cu_seqlens = [0, 360, 587, 663, 696, 731, 783, 1016, 1024]
max diff for output in varlen_mamba fwd pass: 5.148351192474365e-06
mean diff for output in varlen_mamba fwd pass: 3.023584440597915e-07
max diff for A_log in varlen_mamba bwd pass: 1.1734664440155029e-07
mean diff for A_log in varlen_mamba bwd pass: 4.355594551697095e-09
max diff for D in varlen_mamba bwd pass: 2.09808349609375e-05
mean diff for D in varlen_mamba bwd pass: 2.792159648379311e-06
max diff for in_proj.weight in varlen_mamba bwd pass: 0.002126932144165039 (larger than the others)
mean diff for in_proj.weight in varlen_mamba bwd pass: 4.3531857954803854e-05
max diff for conv1d.weight in varlen_mamba bwd pass: 2.5033950805664062e-05
mean diff for conv1d.weight in varlen_mamba bwd pass: 3.934956112061627e-06
max diff for conv1d.bias in varlen_mamba bwd pass: 4.76837158203125e-05
mean diff for conv1d.bias in varlen_mamba bwd pass: 6.707944066874916e-06
max diff for x_proj.weight in varlen_mamba bwd pass: 9.655207395553589e-05
mean diff for x_proj.weight in varlen_mamba bwd pass: 1.2083935416740132e-06
max diff for dt_proj.weight in varlen_mamba bwd pass: 2.1980376914143562e-06
mean diff for dt_proj.weight in varlen_mamba bwd pass: 1.9807760764933846e-08
max diff for dt_proj.bias in varlen_mamba bwd pass: 2.2165477275848389e-07
mean diff for dt_proj.bias in varlen_mamba bwd pass: 1.6080232256854288e-08
max diff for out_proj.weight in varlen_mamba bwd pass: 0.001113295555114746 (larger than the others)
mean diff for out_proj.weight in varlen_mamba bwd pass: 2.5028268737514736e-06

Besides the PyTorch and CUDA versions, I used the same setup you suggested:

pip install causal-conv1d==1.4.0
pip install einops huggingface-hub transformers triton pytest 
git clone https://github.com/zigzagcai/varlen_mamba.git --branch feat/add-cu_seqlens
cd varlen_mamba
pip install -e .

I will now try using cuda 11.8 as well and will let you know if I get the same problem.

@JindongJiang

JindongJiang commented Aug 11, 2024

Hi @zigzagcai, I am back with the CUDA 11.8 results, and the problem still exists. This time I am (almost) fully following your setup script:

# I first pulled and started a CUDA 11.8 container
conda create -n mamba_dev python=3.10
conda activate mamba_dev
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install causal-conv1d==1.4.0
pip install einops huggingface-hub transformers triton pytest 
git clone https://github.com/zigzagcai/varlen_mamba.git --branch feat/add-cu_seqlens
cd varlen_mamba
pip install --no-build-isolation -e . 

The only difference is that I have to pass --no-build-isolation to the final pip install; otherwise I get

RuntimeError:
      The detected CUDA version (11.8) mismatches the version that was used to compile
      PyTorch (12.1). Please make sure to use the same CUDA versions.

Complete results and env:

> nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

> python -c 'import torch; print(torch.version.cuda)'
11.8

> python -c 'import torch; print(torch.__version__)'
2.1.0+cu118

> python tests/ops/test_mamba_cu_seqlens_equivalence.py
Generate random cu_seqlens = [0, 47, 108, 301, 355, 654, 710, 1018, 1024]
max diff for output in varlen_mamba fwd pass: 5.170702934265137e-06
mean diff for output in varlen_mamba fwd pass: 2.9561221026597195e-07
max diff for A_log in varlen_mamba bwd pass: 3.259629011154175e-07
mean diff for A_log in varlen_mamba bwd pass: 5.464801056120905e-09
max diff for D in varlen_mamba bwd pass: 2.0265579223632812e-05
mean diff for D in varlen_mamba bwd pass: 2.7891701392945834e-06
max diff for in_proj.weight in varlen_mamba bwd pass: 0.0019719600677490234 (still quite large)
mean diff for in_proj.weight in varlen_mamba bwd pass: 4.3339219701010734e-05
max diff for conv1d.weight in varlen_mamba bwd pass: 2.4557113647460938e-05
mean diff for conv1d.weight in varlen_mamba bwd pass: 3.905551238858607e-06
max diff for conv1d.bias in varlen_mamba bwd pass: 4.1484832763671875e-05
mean diff for conv1d.bias in varlen_mamba bwd pass: 6.763219062122516e-06
max diff for x_proj.weight in varlen_mamba bwd pass: 8.52346420288086e-05
mean diff for x_proj.weight in varlen_mamba bwd pass: 1.2547155847641989e-06
max diff for dt_proj.weight in varlen_mamba bwd pass: 8.707866072654724e-07
mean diff for dt_proj.weight in varlen_mamba bwd pass: 1.8839402926573712e-08
max diff for dt_proj.bias in varlen_mamba bwd pass: 2.5704503059387207e-07
mean diff for dt_proj.bias in varlen_mamba bwd pass: 1.7761198733978745e-08
max diff for out_proj.weight in varlen_mamba bwd pass: 0.0009202957153320312 (still quite large)
mean diff for out_proj.weight in varlen_mamba bwd pass: 2.580519321782049e-06

It is actually quite surprising that the big discrepancies only happen at the beginning and end: in_proj and out_proj. Could you provide some comments on this? Thanks!

@zigzagcai
Contributor Author

zigzagcai commented Aug 15, 2024

Hi @JindongJiang ,

The error below is caused by the project.toml build system; it is also a commonly encountered issue in the vllm project: vllm-project/vllm#129 (comment)

RuntimeError:
      The detected CUDA version (11.8) mismatches the version that was used to compile
      PyTorch (12.1). Please make sure to use the same CUDA versions.
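The error message compares two versions: the CUDA toolkit found on the machine and the CUDA version the PyTorch used at build time was compiled for. Seeing 12.1 while the environment holds a cu118 wheel would be consistent with the build running against a different PyTorch than the installed one, though that is an inference from this thread rather than something confirmed here. A minimal sketch of the same comparison, using only the standard library; the helper names `nvcc_release` and `cuda_versions_match` are hypothetical and not part of any package:

```python
import re


def nvcc_release(nvcc_output):
    """Extract 'major.minor' from the text printed by `nvcc --version`."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        raise ValueError("no 'release X.Y' found in nvcc output")
    return f"{match.group(1)}.{match.group(2)}"


def cuda_versions_match(nvcc_output, torch_cuda):
    """Return True when the toolkit version equals PyTorch's build version."""
    return nvcc_release(nvcc_output) == torch_cuda


# One line of the nvcc output quoted in this thread.
sample = "Cuda compilation tools, release 11.8, V11.8.89"

print(cuda_versions_match(sample, "11.8"))  # versions agree
print(cuda_versions_match(sample, "12.1"))  # the mismatch from the error
```

In a real environment the second argument would come from `torch.version.cuda`, the same attribute printed elsewhere in this conversation.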

@zigzagcai
Contributor Author

zigzagcai commented Aug 15, 2024

I just reverted the recent merge commit in 0a15f1d.
As a result, project.toml is removed, and pip install -e . should work.

Could you please retry my branch? I just re-tested the code in my environment and it works.

conda create -n mamba_dev python=3.10
conda activate mamba_dev
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install causal-conv1d==1.4.0
pip install einops huggingface-hub transformers triton pytest 
git clone https://github.com/zigzagcai/varlen_mamba.git --branch feat/add-cu_seqlens
cd varlen_mamba
pip install -e .

The test results

python tests/ops/test_mamba_cu_seqlens_equivalence.py

Generate random cu_seqlens = [0, 23, 124, 465, 501, 678, 847, 1000, 1024]
max diff for output in varlen_mamba fwd pass: 6.854534149169922e-07
mean diff for output in varlen_mamba fwd pass: 3.783769386700442e-08
max diff for A_log in varlen_mamba bwd pass: 6.51925802230835e-08
mean diff for A_log in varlen_mamba bwd pass: 7.246340749667013e-10
max diff for D in varlen_mamba bwd pass: 4.410743713378906e-06
mean diff for D in varlen_mamba bwd pass: 6.200841653480893e-07
max diff for in_proj.weight in varlen_mamba bwd pass: 2.002716064453125e-05
mean diff for in_proj.weight in varlen_mamba bwd pass: 1.0927163884844049e-06
max diff for conv1d.weight in varlen_mamba bwd pass: 5.7220458984375e-06
mean diff for conv1d.weight in varlen_mamba bwd pass: 8.621824463261873e-07
max diff for conv1d.bias in varlen_mamba bwd pass: 9.775161743164062e-06
mean diff for conv1d.bias in varlen_mamba bwd pass: 1.4727456800756045e-06
max diff for x_proj.weight in varlen_mamba bwd pass: 7.62939453125e-06
mean diff for x_proj.weight in varlen_mamba bwd pass: 3.4194002296317194e-08
max diff for dt_proj.weight in varlen_mamba bwd pass: 1.1408701539039612e-08
mean diff for dt_proj.weight in varlen_mamba bwd pass: 4.2962927659928596e-10
max diff for dt_proj.bias in varlen_mamba bwd pass: 4.516914486885071e-08
mean diff for dt_proj.bias in varlen_mamba bwd pass: 3.3461828863323717e-09
max diff for out_proj.weight in varlen_mamba bwd pass: 6.67572021484375e-06
mean diff for out_proj.weight in varlen_mamba bwd pass: 2.732678581196524e-07

FYI, my local environment (CUDA version and pip packages):

nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
pip list

Package                       Version               Editable project location
----------------------------- --------------------- ---------------------------------
alabaster                     0.7.13
appdirs                       1.4.4
attrs                         23.1.0
Babel                         2.12.1
cattrs                        23.1.2
causal-conv1d                 1.4.0
certifi                       2022.12.7
charset-normalizer            2.1.1
coloredlogs                   15.0.1
einops                        0.8.0
environs                      11.0.0
esbonio                       0.16.1
exceptiongroup                1.2.2
filelock                      3.13.1
fsspec                        2024.2.0
huggingface-hub               0.24.5
humanfriendly                 10.0
humanize                      4.9.0
idna                          3.4
imagesize                     1.4.1
iniconfig                     2.0.0
Jinja2                        3.1.3
lsprotocol                    2023.0.0a3
mamba_ssm                     2.2.2                 /blahblah/zigzagcai/varlen_mamba
MarkupSafe                    2.1.5
marshmallow                   3.21.1
mpmath                        1.3.0
multiprocessing-logging       0.3.4
networkx                      3.2.1
ninja                         1.11.1.1
numpy                         1.26.3
packaging                     24.1
pillow                        10.2.0
pip                           24.2
pluggy                        1.5.0
pygls                         1.0.2
pyspellchecker                0.7.2
pytest                        8.3.2
python-dotenv                 1.0.1
PyYAML                        6.0.2
regex                         2024.7.24
requests                      2.28.1
safetensors                   0.4.4
setuptools                    72.1.0
shared-memory-dict            0.7.2
snowballstemmer               2.2.0
Sphinx                        7.2.5
sphinxcontrib-applehelp       1.0.7
sphinxcontrib-devhelp         1.0.5
sphinxcontrib-htmlhelp        2.0.4
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-qthelp          1.0.6
sphinxcontrib-serializinghtml 1.1.9
sympy                         1.12
tokenizers                    0.19.1
tomli                         2.0.1
torch                         2.1.0+cu118
torch-tb-profiler             0.4.1
torchaudio                    2.1.0+cu118
torchdata                     0.7.1.dev20240618+cpu
torchvision                   0.16.0+cu118
tqdm                          4.66.5
transformers                  4.44.0
triton                        2.1.0
typeguard                     3.0.2
typing_extensions             4.9.0
UltraDict                     0.0.6
urllib3                       1.26.13
wheel                         0.43.0
zstandard                     0.22.0

@zigzagcai
Contributor Author

zigzagcai commented Aug 15, 2024

BTW @JindongJiang, which GPU model are you using: A100, H100, or something else? Knowing this would help me better understand your software and hardware environment.

@JindongJiang

Hi @zigzagcai, thank you very much for the updates and new commits. I will test the new setup. I got the above results on an A100.

@zigzagcai zigzagcai force-pushed the feat/add-cu_seqlens branch from 0a15f1d to cda4b5a Compare August 15, 2024 06:29
@JindongJiang

Hi @zigzagcai, it seems that the grad discrepancy only exists when I use a docker image in slurm. I have two ways to run the experiments:

  1. With conda env and without docker, works fine
# branch installed in mamba_env with project.toml removed 
> srun -A $ACCOUNT -t 00:05:00 --partition $PARTITION --job-name job_name --gpus 1 bash -c 'source activate mamba_env; python tests/ops/test_mamba_cu_seqlens_equivalence.py'

# Results
Generate random cu_seqlens = [0, 122, 132, 373, 545, 620, 958, 966, 1024]
max diff for output in varlen_mamba fwd pass: 6.556510925292969e-07
mean diff for output in varlen_mamba fwd pass: 3.8284465375681975e-08
max diff for A_log in varlen_mamba bwd pass: 2.3748725652694702e-08
mean diff for A_log in varlen_mamba bwd pass: 6.788294371062875e-10
max diff for D in varlen_mamba bwd pass: 4.291534423828125e-06
mean diff for D in varlen_mamba bwd pass: 6.210023002495291e-07
max diff for in_proj.weight in varlen_mamba bwd pass: 1.8596649169921875e-05
mean diff for in_proj.weight in varlen_mamba bwd pass: 1.1082604487455683e-06
max diff for conv1d.weight in varlen_mamba bwd pass: 5.125999450683594e-06
mean diff for conv1d.weight in varlen_mamba bwd pass: 8.711153896001633e-07
max diff for conv1d.bias in varlen_mamba bwd pass: 7.808208465576172e-06
mean diff for conv1d.bias in varlen_mamba bwd pass: 1.5153941603784915e-06
max diff for x_proj.weight in varlen_mamba bwd pass: 2.637505531311035e-06
mean diff for x_proj.weight in varlen_mamba bwd pass: 3.0485256985457454e-08
max diff for dt_proj.weight in varlen_mamba bwd pass: 8.381903171539307e-09
mean diff for dt_proj.weight in varlen_mamba bwd pass: 3.972099316129629e-10
max diff for dt_proj.bias in varlen_mamba bwd pass: 2.514570951461792e-08
mean diff for dt_proj.bias in varlen_mamba bwd pass: 2.534548571020423e-09
max diff for out_proj.weight in varlen_mamba bwd pass: 5.7220458984375e-06
mean diff for out_proj.weight in varlen_mamba bwd pass: 2.7322744244884234e-07
  2. With a docker image and varlen_mamba installed in the system Python. Not working:
# branch installed for /usr/bin/python
> srun -A $ACCOUNT -t 00:05:00 --partition $PARTITION --job-name job_name --container-image=/path_to_image \
--gpus 1 bash -c 'python /path_to/tests/ops/test_mamba_cu_seqlens_equivalence.py'

# Results
Generate random cu_seqlens = [0, 81, 82, 139, 328, 377, 569, 724, 1024]
max diff for output in varlen_mamba fwd pass: 5.111098289489746e-06
mean diff for output in varlen_mamba fwd pass: 3.1440134762306116e-07
max diff for A_log in varlen_mamba bwd pass: 2.738088369369507e-07
mean diff for A_log in varlen_mamba bwd pass: 6.286445142222874e-09
max diff for D in varlen_mamba bwd pass: 2.002716064453125e-05
mean diff for D in varlen_mamba bwd pass: 2.825587216648273e-06
max diff for in_proj.weight in varlen_mamba bwd pass: 0.0038983821868896484
mean diff for in_proj.weight in varlen_mamba bwd pass: 4.323392204241827e-05
Traceback (most recent call last):
  File "/path_to/tests/ops/test_mamba_cu_seqlens_equivalence.py", line 125, in <module>
    main()
  File "/path_to/tests/ops/test_mamba_cu_seqlens_equivalence.py", line 122, in main
    assert torch.allclose(mamba_grad[name], mamba_ref_grad[name], rtol=rtol, atol=atol)
AssertionError

Thank you for your help again! I think the problem is not in the implementation then. I will use conda without docker for now.
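For readers comparing the two logs, the failing assertion uses the element-wise criterion documented for torch.allclose: |actual - expected| <= atol + rtol * |expected|. The sketch below reproduces that criterion in plain Python for two of the in_proj.weight max diffs reported above. The tolerance values and the reference value 0.5 are assumptions chosen for illustration; the thread does not show which rtol and atol the test script uses.

```python
def diff_report(actual, expected, rtol, atol):
    """Return (max_abs_diff, mean_abs_diff, all_close) for two flat lists."""
    diffs = [abs(a - e) for a, e in zip(actual, expected)]
    # Element-wise criterion documented for torch.allclose:
    # |actual - expected| <= atol + rtol * |expected|
    close = all(d <= atol + rtol * abs(e) for d, e in zip(diffs, expected))
    return max(diffs), sum(diffs) / len(diffs), close


# Hypothetical tolerances and reference value, for illustration only.
RTOL, ATOL = 1e-4, 1e-3
reference = [0.5]

# Max diffs taken from the two in_proj.weight log lines above.
docker_case = [0.5 + 0.0038983821868896484]
conda_case = [0.5 + 1.8596649169921875e-05]

docker_max, _, docker_close = diff_report(docker_case, reference, RTOL, ATOL)
conda_max, _, conda_close = diff_report(conda_case, reference, RTOL, ATOL)

print(docker_max, docker_close)
print(conda_max, conda_close)
```

Under these assumed tolerances the docker-run diff exceeds the bound while the conda-run diff stays well inside it, which matches the pass/fail pattern reported above.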

@zigzagcai
Contributor Author

zigzagcai commented Aug 21, 2024

Very glad to see it is helpful to you!

You are right. I guess there may be conflicts when packages are installed into the /usr/bin/python environment of an NVIDIA docker image. It is safer to start from a fresh virtual environment, even inside docker.

@bali-eng

bali-eng commented Sep 4, 2024

Hi @zigzagcai
Thanks for the cool PR.

Here is how I install dependencies, which might be useful for those working with CUDA 12.5:

conda create -n your_env_name python=3.10.13
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install -r requirements.txt
git clone git@github.com:hustvl/Vim.git
pip install -e causal-conv1d>=1.1.0
pip install -e mamba-1p1p1
pip install --upgrade huggingface-hub==0.24.0

I made a slight adjustment to your example, and here is the revised version:

from collections import Counter
from itertools import chain

import torch
from mamba_ssm.modules.mamba_simple import Mamba
from torch.nn.utils.rnn import pad_sequence

sentences = [
    "Apples.",
    "The dog barked.",
    "She smiled warmly at him.",
    "The sun set behind the mountains.",
]

word_counter = Counter(chain(*[sentence.lower().split() for sentence in sentences]))
vocab = {word: i + 1 for i, (word, _) in enumerate(word_counter.most_common())}
sequences = [
    [vocab[word] for word in sentence.lower().split()] for sentence in sentences
]
padded_sequences = pad_sequence(
    [torch.tensor(seq) for seq in sequences], batch_first=True, padding_value=-500
)


def variable_length_sequences(new_tensor):
    new_tensor_reeshaped = new_tensor.reshape(-1, 1).squeeze(1)
    new_tensor_reeshaped_index = [
        idx for idx, i in enumerate(new_tensor_reeshaped) if i != -500
    ]
    start_indexes = []
    last_index = None
    for idx in new_tensor_reeshaped_index:
        if last_index is None or idx != last_index + 1:
            start_indexes.append(idx)
        last_index = idx
    return torch.tensor(start_indexes), torch.tensor(new_tensor_reeshaped)


def unpack(packed_hidden_states, cu_seqlens):
    batch_size = cu_seqlens.shape[0] - 1
    seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
    hidden_dim = packed_hidden_states.shape[2]
    hidden_states = torch.zeros(
        batch_size,
        seq_len,
        hidden_dim,
        dtype=packed_hidden_states.dtype,
        device=packed_hidden_states.device,
    )
    for i in range(batch_size):
        hidden_states[i, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[
            :, cu_seqlens[i] : cu_seqlens[i + 1], :
        ]
    return hidden_states


def pack(hidden_states, cu_seqlens):
    batch_size, seq_len, hidden_dim = hidden_states.shape
    seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1]
    seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2)
    indices_3d = (
        torch.arange(seq_len, device=hidden_states.device)
        .unsqueeze(0)
        .unsqueeze(2)
        .repeat(batch_size, 1, hidden_dim)
    )
    mask_3d = indices_3d < seq_len_list_3d
    packed_hidden_states = hidden_states[mask_3d].view(-1, hidden_dim)
    return packed_hidden_states


hidden_dim = 256
seq_len = 1024
batch_size = 8
device = "cuda"
mamba = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=hidden_dim,  # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
).to(device)

new_tensor_reeshaped_index = variable_length_sequences(padded_sequences)
hidden_states_list = [
    torch.randn(l, hidden_dim, device=device)
    for l in (
        new_tensor_reeshaped_index[0][1:] - new_tensor_reeshaped_index[0][:-1]
    ).tolist()
]
packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0)
hidden_states = unpack(packed_hidden_states, new_tensor_reeshaped_index[0])

out_ref = mamba(hidden_states)
out_ref = pack(out_ref, new_tensor_reeshaped_index[0].to("cuda")).unsqueeze(0)
out = mamba(packed_hidden_states, new_tensor_reeshaped_index[0].to("cuda"))
unpack(out, new_tensor_reeshaped_index[0]).shape

I noticed that when processing 4 sentences, you receive embeddings for only 3 sentences (torch.Size([3, 6, 256])). It might be helpful to append last_index + 1 to the list in your variable_length_sequences function (i.e., start_indexes.append(last_index + 1)). This adjustment should ensure that the number of output sentences matches the number of input sentences (torch.Size([4, 6, 256])).

I am receiving embeddings with a shape of torch.Size([4, 6, 256]). However, one of my sentences contains only three words. Should I apply masking to the returned sequences to remove embeddings that might not be meaningful?
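One possible reading of the shapes above: the start indexes are taken from the flattened padded tensor, so every boundary sits at a multiple of the padded length (0, 6, 12, 18) and each sentence is treated as six tokens long. If cu_seqlens is meant to mark boundaries inside the packed sequence, as the PR description states, it would instead be the running sum of the true sentence lengths. The sketch below shows that computation in plain Python, plus a validity mask for the zero-padded positions an unpack step would produce. The helper name `cu_seqlens_from_lengths` is hypothetical, and this is an illustration rather than the PR author's recommended method.

```python
from itertools import accumulate


def cu_seqlens_from_lengths(lengths):
    """Cumulative boundaries: sequence i spans [cu[i], cu[i + 1])."""
    return [0] + list(accumulate(lengths))


sentences = [
    "Apples.",
    "The dog barked.",
    "She smiled warmly at him.",
    "The sun set behind the mountains.",
]

# True token counts per sentence (whitespace tokenization, as in the example).
lengths = [len(sentence.split()) for sentence in sentences]
cu_seqlens = cu_seqlens_from_lengths(lengths)

# After unpacking to (batch, max_len, hidden_dim), only positions j < length
# hold real tokens; the rest are padding and can be masked out.
max_len = max(lengths)
mask = [[j < n for j in range(max_len)] for n in lengths]

print(lengths)
print(cu_seqlens)
print(mask[1])
```

With boundaries built this way the packed input holds 15 tokens instead of 24, and the mask identifies which rows of the unpacked output correspond to real words.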

Thanks,

@zongtianhu

Hi,

image

To give a simple example: what we originally pass into the original mamba block is an input with shape (batch_size=7, seq_len=10, hidden_dim). Through this PR, we can instead pass the variable-length mamba block an input with shape (batch_size=1, seq_len=32, hidden_dim), where the original variable-length sequences are packed into one fixed-length sequence, with an additional parameter cu_seqlens to mark sequence boundaries.

From the above figure, we can clearly see that through this PR, mamba block can focus computing resources on variable-length sequences and avoid the overhead of meaningless padding tokens.

Variable-length training is very useful for optimizing the hardware utilization during training, and we know that the well-known flash attention has supported variable-length training via cu_seqlens. Therefore, we believe that mamba, as a competitor of transformer, can improve its hardware utilization during training on real world datasets (the length distribution varies much between data samples) through this PR!
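The packing described above can be sketched with plain Python lists. The individual sequence lengths below are hypothetical (the figure is not reproduced in this text); they were chosen only so that seven sequences with a maximum length of 10 sum to 32 tokens, matching the shapes in the quoted example. The `pack` and `unpack` helpers are simplified stand-ins for the tensor versions posted earlier in the thread.

```python
def pack(padded, lengths):
    """Concatenate the first lengths[i] tokens of each padded row."""
    packed = [tok for row, n in zip(padded, lengths) for tok in row[:n]]
    cu_seqlens = [0]
    for n in lengths:
        cu_seqlens.append(cu_seqlens[-1] + n)
    return packed, cu_seqlens


def unpack(packed, cu_seqlens, pad=0):
    """Rebuild a padded batch from the packed tokens and their boundaries."""
    spans = list(zip(cu_seqlens, cu_seqlens[1:]))
    max_len = max(end - start for start, end in spans)
    return [
        packed[start:end] + [pad] * (max_len - (end - start))
        for start, end in spans
    ]


# Hypothetical lengths: 7 sequences, max length 10, 32 tokens in total.
lengths = [10, 3, 5, 2, 6, 4, 2]
max_len = max(lengths)

# Non-zero token ids for real positions, 0 for padding.
padded = [
    [i * 100 + j + 1 if j < n else 0 for j in range(max_len)]
    for i, n in enumerate(lengths)
]

packed, cu_seqlens = pack(padded, lengths)
utilization = len(packed) / (len(padded) * max_len)

print(len(packed), cu_seqlens)
print(f"useful tokens in the padded batch: {utilization:.0%}")
```

In this hypothetical batch fewer than half of the padded positions carry real tokens, which is the overhead the packed layout avoids.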

Thank you very much for your code and illustrations, but I have some questions about the parameters seqlen and seq_idx in Mamba2 shown in the following figure. Could you provide a corresponding illustration for these parameters?
image
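As a rough aid for the question above, one common way to relate the two conventions is to expand cu_seqlens into a per-token index, where every packed position stores the number of the sequence it belongs to. The sketch below assumes that seq_idx follows this per-token convention; the thread does not confirm it, the helper name `seq_idx_from_cu_seqlens` is hypothetical, and the seqlen parameter is not covered here.

```python
def seq_idx_from_cu_seqlens(cu_seqlens):
    """Expand cumulative boundaries into one sequence index per token."""
    seq_idx = []
    for i, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        seq_idx.extend([i] * (end - start))
    return seq_idx


# Hypothetical boundaries: three sequences of lengths 3, 2 and 4.
cu_seqlens = [0, 3, 5, 9]
seq_idx = seq_idx_from_cu_seqlens(cu_seqlens)

print(seq_idx)
```

Both forms carry the same boundary information: cu_seqlens needs one entry per sequence plus one, while the per-token form needs one entry per packed position.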
