From 8ffd905c91d207f5c0cc84fc2a2fb748655094f0 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 3 Jul 2024 15:55:39 -0700 Subject: [PATCH] Fix varlen generation by passing seq_idx to causal_conv1d --- mamba_ssm/__init__.py | 2 +- mamba_ssm/modules/mamba2.py | 2 ++ tests/test_generation.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py index b1b96a32..673ee32a 100644 --- a/mamba_ssm/__init__.py +++ b/mamba_ssm/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.2.1" +__version__ = "2.2.2" from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.modules.mamba_simple import Mamba diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 854ad0a8..85fd6dec 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -226,6 +226,7 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param conv_state.copy_(conv_varlen_states) assert self.activation in ["silu", "swish"] if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + assert seq_idx is None, "varlen conv1d requires the causal_conv1d package" xBC = self.act( self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):] ) # (B, L, self.d_ssm + 2 * ngroups * d_state) @@ -235,6 +236,7 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, + seq_idx=seq_idx, ).transpose(1, 2) x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1) y = mamba_chunk_scan_combined( diff --git a/tests/test_generation.py b/tests/test_generation.py index 16949023..77e1aedf 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -110,4 +110,4 @@ def test_generation_varlen(): sequences.append(sampled_tokens) out_varlen = torch.cat(scores, dim=1) print(f"Max diff: {(out_varlen - out_ref).abs().max()}") - assert (out_varlen - out_ref).abs().max() < 5 * (out_loop - out_ref).abs().max() + assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()