diff --git a/tests/models/decoder_only/language/test_bamba.py b/tests/models/decoder_only/language/test_bamba.py
index a3bcb644baf8b..d266135360563 100644
--- a/tests/models/decoder_only/language/test_bamba.py
+++ b/tests/models/decoder_only/language/test_bamba.py
@@ -1,6 +1,6 @@
 """Compare the outputs of HF and vLLM when using greedy sampling for Mamba.
 
-This actually is really indentical to test_mamba, so maybe we can reuse
+This is essentially identical to test_mamba, so maybe we can reuse the test code.
 
 Run `pytest tests/models/decoder_only/language/test_bamba.py`.
 """
@@ -97,6 +97,7 @@ def test_batching(
         name_1="batched_vllm",
     )
 
+
 @pytest.mark.skip("bamba does not support chunked prefill yet")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
@@ -122,6 +123,7 @@ def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts,
     ) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
 
+
 @pytest.mark.skip("bamba does not support chunked prefill yet")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
@@ -205,7 +207,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    while len(example_prompts) == VllmConfig.get_graph_batch_size(len(example_prompts)):
+    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+            len(example_prompts)):
         example_prompts.append(example_prompts[0])
 
     try:
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index b2a4b2aaefc78..150ee86b4ca3b 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -1,36 +1,35 @@
+from typing import List, Optional, Tuple, Union
+
 import torch
 from torch import nn
-from torch.nn.parameter import Parameter
-
-# Added by the IBM Team, 2024
 
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
-
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_state_update)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
     mamba_chunk_scan_combined)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import (
+    LoaderFunction, composed_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.distributed import (divide, get_tensor_model_parallel_world_size,
-                              get_tensor_model_parallel_rank,
-                              tensor_model_parallel_all_reduce)
-from vllm.model_executor.model_loader.weight_utils import (
-    composed_weight_loader, sharded_weight_loader, LoaderFunction)
 
-from typing import Tuple, Union, Optional, List
-from vllm.model_executor.custom_op import CustomOp
+# Added by the IBM Team, 2024
+
 
 # Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
 # also referenced https://github.com/vllm-project/vllm/pull/9292
 @CustomOp.register("mixer2_gated_rms_norm")
 class Mixer2RMSNormGated(CustomOp):
+
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         self.hidden_size = hidden_size
@@ -84,6 +83,7 @@ def forward_cuda(
         )
         return out
 
+
 def extra_groups_for_head_shards(ngroups: int, tp_size: int):
     """Compute the extra (logical) groups to account for head shards"""
 
@@ -93,12 +93,16 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
 
     return tp_size - ngroups % tp_size
 
+
 def mamba_v2_sharded_weight_loader(
-    shard_spec: List[int], tp_size: int, tp_rank: int,
+    shard_spec: List[Tuple[int, int, float]],
+    tp_size: int,
+    tp_rank: int,
 ) -> LoaderFunction:
-    """Create a weight loader for mamba v2. This ensures that the projections are
-    correctly sharded so that they can be split into x, B, C. It also ensures the 
-    the all the groups corresponding to a head shard is placed together with it.
+    """Create a weight loader for mamba v2. This ensures that the projections 
+    are correctly sharded so that they can be split into x, B, C. It also 
+    ensures the the all the groups corresponding to a head shard is placed 
+    together with it.
     """
 
     def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
@@ -116,18 +120,21 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
             rank = tp_rank // ratio
 
             # - should start from here (determined by rank)
-            loaded_skip = rank * shard_size # take these number dims from loaded
+            # - take this many dims from the loaded weight
+            loaded_skip = rank * shard_size
             loaded_start_idx = loaded_boundary + loaded_skip
 
             # - these many number dims to take from loaded_weight
             take = min(shard_size, full_dim - extra - loaded_skip)
 
             # - always shard on dim 0
-            param.data[
-                boundary:boundary+take,...
-            ] = loaded_weight[
-                loaded_start_idx:loaded_start_idx+take
-            ]
+            # - the ignore is for a mundane mypy error as it does not
+            #   seem to handle slices well.
+            # https://github.com/python/mypy/issues/2410
+            param.data[boundary:(boundary + take),  # type: ignore[misc]
+                       ...] = loaded_weight[
+                           loaded_start_idx:(  # type: ignore[misc]
+                               loaded_start_idx + take)]  # type: ignore[misc]
 
             # move boundaries
             boundary += shard_size
@@ -135,8 +142,9 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
 
     return loader
 
+
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
-@CustomOp.register("mamba_mixer2") 
+@CustomOp.register("mamba_mixer2")
 class MambaMixer2(CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
@@ -165,17 +173,17 @@ def __init__(self,
         super().__init__()
 
         # For TP, the sharding plan is as follows:
-        # - for the conv modules, since 
+        # - for the conv modules, since
         #   conv_dim = intermediate_size * 2 * n_groups * ssm_state_size,
         #   we shard intermediate_size and n_groups
         # - since intermediate_size = n_heads * head_dim, sharding on
         #   intermediate_size is achieved by sharding on n_heads.
-        # - so if world_size divides groups, then sharding 
+        # - so if world_size divides groups, then sharding
         #   (n_groups / world_size, n_heads / world_size)
         #   also maintains the invariant n_heads % n_groups == 0
-        # - HOWEVER< if world_size DOES NOT divide groups, then we need to allocate
-        #   extra space in the shard, such that the WHOLE GROUP must be placed
-        #   together with the HEAD SHARD.
+        # - HOWEVER, if world_size DOES NOT divide groups, then we need
+        #   to allocate extra space in the shard, such that the WHOLE GROUP
+        #   must be placed together with the HEAD SHARD.
         self.tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
 
@@ -190,14 +198,14 @@ def __init__(self,
 
         self.n_groups = n_groups
         if n_groups % self.tp_size != 0:
-            # - for TP we shard conv_dim by sharding on n_groups, 
-            # - but if n_groups cannot divide tp_size, we need to 
+            # - for TP we shard conv_dim by sharding on n_groups,
+            # - but if n_groups cannot divide tp_size, we need to
             #   extend some extra groups
-            self.n_groups = n_groups + extra_groups_for_head_shards(n_groups, self.tp_size)
+            self.n_groups = n_groups + extra_groups_for_head_shards(
+                n_groups, self.tp_size)
 
-        self.conv_dim = (
-            intermediate_size + 2 * self.n_groups * ssm_state_size
-        )
+        self.conv_dim = (intermediate_size +
+                         2 * self.n_groups * ssm_state_size)
         self.conv1d = ColumnParallelLinear(
             input_size=conv_kernel_size,
             output_size=self.conv_dim,
@@ -210,62 +218,76 @@ def __init__(self,
         # doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
 
-        self.in_proj = ColumnParallelLinear(
-            input_size=hidden_size,
-            output_size=intermediate_size + self.conv_dim + self.num_heads,
-            bias=use_bias,
-            quant_config=quant_config)
+        self.in_proj = ColumnParallelLinear(input_size=hidden_size,
+                                            output_size=intermediate_size +
+                                            self.conv_dim + self.num_heads,
+                                            bias=use_bias,
+                                            quant_config=quant_config)
 
-        # - because in_proj is a concatenation of 3 weights, we 
+        # - because in_proj is a concatenation of 3 weights, we
         #   need to interleave them before sharding
         # - use the custom weight loader mamba_v2_sharded_weight_loader
         #   for conv1d.bias, covn1d.weight and in_proj.weight
         # - need to set these settings, to assign the groups to the head shards
         group_shard_settings = (
-            self.n_groups * self.ssm_state_size, # expected model size
-            (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
-            self.num_heads // n_groups, # ratio for mapping back to original group
+            self.n_groups * self.ssm_state_size,  # expected model size
+            (self.n_groups - n_groups) *
+            self.ssm_state_size,  # extra dims assigned
+            self.num_heads //
+            n_groups,  # ratio for mapping back to original group
         )
         intemediate_settings = (intermediate_size, 0, 1)
         head_setings = (self.num_heads, 0, 1)
 
         delattr(self.conv1d.bias, "weight_loader")
-        set_weight_attrs(self.conv1d.bias, {
-            "weight_loader": mamba_v2_sharded_weight_loader(
-                [
-                    intemediate_settings, group_shard_settings, group_shard_settings,
-                ],
-                self.tp_size, tp_rank,
-            )
-        })
+        set_weight_attrs(
+            self.conv1d.bias, {
+                "weight_loader":
+                mamba_v2_sharded_weight_loader(
+                    [
+                        intemediate_settings,
+                        group_shard_settings,
+                        group_shard_settings,
+                    ],
+                    self.tp_size,
+                    tp_rank,
+                )
+            })
 
         delattr(self.conv1d.weight, "weight_loader")
-        set_weight_attrs(self.conv1d.weight, {
-            "weight_loader": mamba_v2_sharded_weight_loader(
-                [
-                    intemediate_settings, group_shard_settings, group_shard_settings,
-                ],
-                self.tp_size, tp_rank
-            )
-        })
+        set_weight_attrs(
+            self.conv1d.weight, {
+                "weight_loader":
+                mamba_v2_sharded_weight_loader([
+                    intemediate_settings,
+                    group_shard_settings,
+                    group_shard_settings,
+                ], self.tp_size, tp_rank)
+            })
 
         delattr(self.in_proj.weight, "weight_loader")
-        set_weight_attrs(self.in_proj.weight, {
-            "weight_loader": mamba_v2_sharded_weight_loader(
-                [
-                    intemediate_settings, # for gate
-                    intemediate_settings, group_shard_settings, group_shard_settings,
-                    head_setings,  # for dt
-                ],
-                self.tp_size, tp_rank
-            )
-        })
-
-        # - these are TPed by heads to reduce the size of the 
+        set_weight_attrs(
+            self.in_proj.weight,
+            {
+                "weight_loader":
+                mamba_v2_sharded_weight_loader(
+                    [
+                        intemediate_settings,  # for gate
+                        intemediate_settings,
+                        group_shard_settings,
+                        group_shard_settings,
+                        head_setings,  # for dt
+                    ],
+                    self.tp_size,
+                    tp_rank)
+            })
+
+        # - these are TPed by heads to reduce the size of the
         #   temporal shape
         self.A = nn.Parameter(
             torch.empty(
-                divide(num_heads, self.tp_size), dtype=torch.float32,
+                divide(num_heads, self.tp_size),
+                dtype=torch.float32,
             ))
         self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
         self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
@@ -277,16 +299,14 @@ def __init__(self,
         set_weight_attrs(self.dt_bias,
                          {"weight_loader": sharded_weight_loader(0)})
 
-        self.out_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
-            bias=use_bias,
-            input_is_parallel=True,
-            quant_config=quant_config)
+        self.out_proj = RowParallelLinear(intermediate_size,
+                                          hidden_size,
+                                          bias=use_bias,
+                                          input_is_parallel=True,
+                                          quant_config=quant_config)
 
-        self.norm = Mixer2RMSNormGated(
-            intermediate_size // self.tp_size, eps=rms_norm_eps
-        )
+        self.norm = Mixer2RMSNormGated(intermediate_size // self.tp_size,
+                                       eps=rms_norm_eps)
 
     def forward_native(self, hidden_states: torch.Tensor,
                        attn_metadata: AttentionMetadata,
@@ -297,27 +317,27 @@ def forward_cuda(self, hidden_states: torch.Tensor,
                      attn_metadata: AttentionMetadata,
                      mamba_cache_params: MambaCacheParams):
 
-
         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
 
         # - doing it differently from mixer v1; little confused with its logic
-        # - we need to do is to detect if there is any prefill; if there are 
+        # - what we need to do is detect if there is any prefill; if there are
         #   no prefils, then each example will be coming in one sample at a time
-        # - on the other hand v1 checks for "query_start_loc" and "context_lens_tensor"
-        #   however we have noticed that, even when the samples are coming in
-        #   one at a time, they are still non-NO.e
+        # - on the other hand, v1 checks for "query_start_loc"
+        #   and "context_lens_tensor"; however, we have noticed that,
+        #   even when the samples are coming in one at a time,
+        #   they are still not None, e.g.,
         #   * "query_start_loc" = [0, 1, ..]
         #   * "context_lens_tensor" = [8, ...]
-        has_prefill = attn_metadata.num_prefills > 0 
+        has_prefill = attn_metadata.num_prefills > 0
 
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
         gate, hidden_states_B_C, dt = torch.split(
             projected_states,
             [
-                self.intermediate_size // self.tp_size, 
-                self.conv_dim // self.tp_size, 
+                self.intermediate_size // self.tp_size,
+                self.conv_dim // self.tp_size,
                 self.num_heads // self.tp_size,
             ],
             dim=-1,
@@ -335,7 +355,7 @@ def forward_cuda(self, hidden_states: torch.Tensor,
             # |-------------------- seq_len ---------------------|
             #                                   |-- query_len ---|
 
-            # - "cache_indices" upates the conv_state cache in positions
+            # - "cache_indices" updates the conv_state cache in positions
             #   pointed to by "mamba_cache_params.state_indices_tensor"
             hidden_states_B_C = causal_conv1d_fn(
                 hidden_states_B_C.transpose(0, 1),
@@ -345,8 +365,8 @@ def forward_cuda(self, hidden_states: torch.Tensor,
                 conv_states=mamba_cache_params.conv_state,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
                 cache_indices=mamba_cache_params.state_indices_tensor,
-                query_start_loc=attn_metadata.query_start_loc
-            ).transpose(0, 1)[:seq_len]
+                query_start_loc=attn_metadata.query_start_loc).transpose(
+                    0, 1)[:seq_len]
         else:
             hidden_states_B_C = causal_conv1d_update(
                 hidden_states_B_C,
@@ -354,14 +374,13 @@ def forward_cuda(self, hidden_states: torch.Tensor,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-                conv_state_indices=mamba_cache_params.state_indices_tensor
-            )
+                conv_state_indices=mamba_cache_params.state_indices_tensor)
 
         # - get hidden_states, B and C after depthwise convolution.
         hidden_states, B, C = torch.split(
             hidden_states_B_C,
             [
-                self.intermediate_size // self.tp_size, 
+                self.intermediate_size // self.tp_size,
                 groups_time_state_size // self.tp_size,
                 groups_time_state_size // self.tp_size,
             ],
@@ -370,12 +389,12 @@ def forward_cuda(self, hidden_states: torch.Tensor,
 
         # 3. State Space Model sequence transformation
         if has_prefill:
-            
+
             # FIXME: we are having problems using mamba_chunk_scan_combined
             # with chunked prefill. This is because there is no
             # initial_states requires initial_states.shape[0] to match
             # the batch size, but cu_seqlens requires batch_size = 1.
-            # Therefore as of now, initial_states and cu_seqlens are 
+            # Therefore as of now, initial_states and cu_seqlens are
             # mutually exclusive.
 
             initial_states = None
@@ -385,7 +404,8 @@ def forward_cuda(self, hidden_states: torch.Tensor,
             #     ]
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
-                hidden_states.view(1, seq_len, self.num_heads // self.tp_size, self.head_dim),
+                hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
+                                   self.head_dim),
                 dt.unsqueeze(0),
                 self.A,
                 B.view(1, seq_len, self.n_groups // self.tp_size, -1),
@@ -412,15 +432,17 @@ def forward_cuda(self, hidden_states: torch.Tensor,
             hidden_states = scan_output.view(seq_len, -1)
         else:
 
-            # NOTE: can be optimized? 
+            # NOTE: can be optimized?
             n_groups = self.n_groups // self.tp_size
-            A = self.A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
+            A = self.A[:, None, ...][:, :, None].expand(
+                -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
             dt = dt[:, :, None].expand(-1, -1, self.head_dim)
             dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
             D = self.D[:, None, ...].expand(-1, self.head_dim)
             B = B.view(-1, n_groups, B.shape[1] // n_groups)
             C = C.view(-1, n_groups, C.shape[1] // n_groups)
-            hidden_states_reshaped = hidden_states.view(-1, self.num_heads // self.tp_size, self.head_dim)
+            hidden_states_reshaped = hidden_states.view(
+                -1, self.num_heads // self.tp_size, self.head_dim)
 
             # - the hidden is reshaped into number of current batches
             # - in this case there is no more prefil, so the batches gen
@@ -434,22 +456,21 @@ def forward_cuda(self, hidden_states: torch.Tensor,
                 mamba_cache_params.ssm_state,
                 hidden_states_reshaped,
                 dt,
-                A, 
+                A,
                 B,
                 C,
-                D, 
+                D,
                 z=None,
                 dt_bias=dt_bias,
                 dt_softplus=True,
                 state_batch_indices=mamba_cache_params.state_indices_tensor,
             )
             hidden_states = hidden_states.view(
-                -1, (self.num_heads // self.tp_size) * self.head_dim
-            )
+                -1, (self.num_heads // self.tp_size) * self.head_dim)
 
         # # 4. gated MLP
         hidden_states = self.norm(hidden_states, gate)
 
         # # 5. Final linear projection
         out, _ = self.out_proj(hidden_states)
-        return out 
\ No newline at end of file
+        return out
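Note: to make the group-to-head-shard arithmetic above concrete, here is a small
standalone Python sketch (not part of the patch) of extra_groups_for_head_shards
and the (size, extra, ratio) tuples fed to mamba_v2_sharded_weight_loader. The
values n_groups=1, tp_size=2, ssm_state_size=16, num_heads=8 are illustrative
assumptions only.

def extra_groups_for_head_shards(ngroups: int, tp_size: int) -> int:
    """Extra (logical) groups so that every head shard owns whole groups."""
    if ngroups % tp_size == 0:
        return 0
    return tp_size - ngroups % tp_size


# Illustrative numbers (assumptions, not taken from any real config):
n_groups, tp_size, ssm_state_size, num_heads = 1, 2, 16, 8

extra = extra_groups_for_head_shards(n_groups, tp_size)  # -> 1
padded_groups = n_groups + extra                         # -> 2

# Mirrors group_shard_settings in MambaMixer2.__init__: each TP rank ends up
# holding a whole group of size ssm_state_size next to its
# num_heads // tp_size = 4 heads.
group_shard_settings = (
    padded_groups * ssm_state_size,               # expected model size
    (padded_groups - n_groups) * ssm_state_size,  # extra dims assigned
    num_heads // n_groups,                        # heads per original group
)
print(extra, padded_groups, group_shard_settings)  # 1 2 (32, 16, 8)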
diff --git a/vllm/model_executor/layers/mamba/ops/softplus.py b/vllm/model_executor/layers/mamba/ops/softplus.py
index 5541655c66160..5ec75be51bf3b 100644
--- a/vllm/model_executor/layers/mamba/ops/softplus.py
+++ b/vllm/model_executor/layers/mamba/ops/softplus.py
@@ -1,15 +1,21 @@
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/softplus.py
+
+# ruff: noqa: E501
+
 import triton
 import triton.language as tl
 from packaging import version
 
 TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
 
-
 if TRITON3:
+
     @triton.jit
     def softplus(dt):
         return tl.math.log(tl.math.exp(dt) + 1)
 else:
+
     @triton.jit
     def softplus(dt):
-        return tl.math.log1p(tl.exp(dt))
\ No newline at end of file
+        return tl.math.log1p(tl.exp(dt))
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py
index 312a65769b634..3eba3c49b4590 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py
@@ -1,51 +1,134 @@
 # Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_bmm.py
 
+# ruff: noqa: E501,SIM102
 """We want triton==2.1.0 or 2.2.0 for this
 """
 
 import math
-import torch
-import torch.nn.functional as F
 
+import torch
 import triton
 import triton.language as tl
 
-from einops import rearrange, repeat
-
 
 def init_to_zero(names):
-    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
+    return lambda nargs: [
+        nargs[name].zero_() for name in names if nargs[name] is not None
+    ]
 
 
 @triton.autotune(
     configs=[
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 256,
+                'BLOCK_SIZE_K': 64
+            },
+            num_stages=3,
+            num_warps=8),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 256,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=5,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 32,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=5,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=2),
     ],
     key=['chunk_size', 'K', 'IS_CAUSAL'],
 )
 @triton.jit
 def _bmm_chunk_fwd_kernel(
     # Pointers to matrices
-    a_ptr, b_ptr, out_ptr, seq_idx_ptr,
+    a_ptr,
+    b_ptr,
+    out_ptr,
+    seq_idx_ptr,
     # Matrix dimensions
-    seqlen, chunk_size, K, ngroups,
-    stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
-    stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
-    stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
-    stride_seq_idx_batch, stride_seq_idx_seqlen,
+    seqlen,
+    chunk_size,
+    K,
+    ngroups,
+    stride_a_batch,
+    stride_a_seqlen,
+    stride_a_head,
+    stride_ak,
+    stride_b_batch,
+    stride_b_seqlen,
+    stride_b_head,
+    stride_bk,
+    stride_out_batch,
+    stride_out_chunk,
+    stride_out_head,
+    stride_outm,
+    stride_outn,
+    stride_seq_idx_batch,
+    stride_seq_idx_seqlen,
     # Meta-parameters
     IS_CAUSAL: tl.constexpr,
     dot_dtype: tl.constexpr,
     HAS_SEQ_IDX: tl.constexpr,
-    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
 ):
     pid_b = tl.program_id(axis=1)
     pid_ch = tl.program_id(axis=2).to(tl.int64)
@@ -65,14 +148,22 @@ def _bmm_chunk_fwd_kernel(
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
-    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
+    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen +
+                      offs_k[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
+                      offs_n[None, :] * stride_b_seqlen)
     chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
 
     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
-        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
+        a = tl.load(a_ptrs,
+                    mask=(offs_m[:, None] < chunk_size_limit) &
+                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                    other=0.0).to(dot_dtype)
+        b = tl.load(b_ptrs,
+                    mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) &
+                    (offs_n[None, :] < chunk_size_limit),
+                    other=0.0).to(dot_dtype)
         acc += tl.dot(a, b)
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -81,16 +172,30 @@ def _bmm_chunk_fwd_kernel(
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     if HAS_SEQ_IDX:
         chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
-        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
-        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
+        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
+                            mask=offs_m < chunk_size_limit,
+                            other=-1)
+        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
+                            mask=offs_n < chunk_size_limit,
+                            other=-2)
         acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
     out = acc.to(out_ptr.dtype.element_ty)
 
     out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
-    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
-    tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
+    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
+                          offs_n[None, :] * stride_outn)
+    tl.store(out_ptrs,
+             out,
+             mask=(offs_m[:, None] < chunk_size) &
+             (offs_n[None, :] < chunk_size))
+
 
-def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
+def _bmm_chunk_fwd(a,
+                   b,
+                   chunk_size,
+                   seq_idx=None,
+                   causal=False,
+                   output_dtype=None):
     """
     Argument:
         a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
@@ -117,20 +222,44 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=No
     nchunks = math.ceil(seqlen / chunk_size)
     # Allocates output.
     out_dtype = a.dtype if output_dtype is None else output_dtype
-    out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
-                      device=a.device, dtype=out_dtype)
-    dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
-                 (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
-    grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
-                    batch, nchunks if not has_groups else nchunks * ngroups)
+    out = torch.empty(
+        (batch, nchunks, chunk_size, chunk_size) if not has_groups else
+        (batch, nchunks, ngroups, chunk_size, chunk_size),
+        device=a.device,
+        dtype=out_dtype)
+    dot_dtype = (tl.bfloat16
+                 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
+                 (tl.float16 if a.dtype == torch.float16
+                  or b.dtype == torch.float16 else tl.float32))
+    grid = lambda META: (triton.cdiv(
+        chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
+            chunk_size, META['BLOCK_SIZE_N']), batch, nchunks
+                         if not has_groups else nchunks * ngroups)
     with torch.cuda.device(a.device.index):
         _bmm_chunk_fwd_kernel[grid](
-            a, b, out, seq_idx,
-            seqlen, chunk_size, k, ngroups if has_groups else 1,
-            a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
-            b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
-            out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
-            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
+            a,
+            b,
+            out,
+            seq_idx,
+            seqlen,
+            chunk_size,
+            k,
+            ngroups if has_groups else 1,
+            a.stride(0),
+            a.stride(1),
+            0 if not has_groups else a.stride(2),
+            a.stride(-1),
+            b.stride(0),
+            b.stride(1),
+            0 if not has_groups else b.stride(2),
+            b.stride(-1),
+            out.stride(0),
+            out.stride(1),
+            0 if not has_groups else out.stride(2),
+            out.stride(-2),
+            out.stride(-1),
+            *((seq_idx.stride(0),
+               seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
             causal,
             dot_dtype,
             HAS_SEQ_IDX=seq_idx is not None,
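Note: as a rough, self-contained reference for what _bmm_chunk_fwd computes
(ignoring the seq_idx masking and the grouped ngroups layout), the following
PyTorch sketch forms the per-chunk a @ b^T products described by the docstring
above. It is an illustration under those simplifying assumptions, not a drop-in
replacement for the Triton kernel.

import math

import torch
import torch.nn.functional as F


def bmm_chunk_ref(a: torch.Tensor, b: torch.Tensor, chunk_size: int,
                  causal: bool = False) -> torch.Tensor:
    """a, b: (batch, seqlen, k) -> out: (batch, nchunks, chunk_size, chunk_size)."""
    batch, seqlen, k = a.shape
    nchunks = math.ceil(seqlen / chunk_size)
    pad = nchunks * chunk_size - seqlen
    # pad the sequence dimension up to a whole number of chunks
    a = F.pad(a, (0, 0, 0, pad)).view(batch, nchunks, chunk_size, k)
    b = F.pad(b, (0, 0, 0, pad)).view(batch, nchunks, chunk_size, k)
    # out[b, c, m, n] = sum_k a[b, c, m, k] * b[b, c, n, k]
    out = torch.einsum("bcmk,bcnk->bcmn", a, b)
    if causal:
        # keep only positions n <= m within each chunk
        out = out.tril()
    return out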
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
index 79fa52e0b8c4f..c538aaa464171 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
@@ -1,55 +1,175 @@
 # Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_chunk_scan.py
 
+# ruff: noqa: E501
 """We want triton==2.1.0 or 2.2.0 for this
 """
 
-from packaging import version
-
 import torch
-
 import triton
 import triton.language as tl
+from packaging import version
 
 TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
 
 
 def init_to_zero(names):
-    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
+    return lambda nargs: [
+        nargs[name].zero_() for name in names if nargs[name] is not None
+    ]
 
 
 @triton.autotune(
     configs=[
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 256,
+                'BLOCK_SIZE_K': 64
+            },
+            num_stages=3,
+            num_warps=8),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 256,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 64
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 64
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=5,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 32,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=5,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=2),
     ],
     key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
 )
 @triton.jit
 def _chunk_scan_fwd_kernel(
     # Pointers to matrices
-    cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,
+    cb_ptr,
+    x_ptr,
+    z_ptr,
+    out_ptr,
+    out_x_ptr,
+    dt_ptr,
+    dA_cumsum_ptr,
+    seq_idx_ptr,
+    C_ptr,
+    prev_states_ptr,
+    D_ptr,
     # Matrix dimensions
-    chunk_size, hdim, dstate,
-    batch, seqlen, nheads_ngroups_ratio,
+    chunk_size,
+    hdim,
+    dstate,
+    batch,
+    seqlen,
+    nheads_ngroups_ratio,
     # Strides
-    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
-    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
-    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
-    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
-    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
-    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
-    stride_seq_idx_batch, stride_seq_idx_seqlen,
-    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
-    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
+    stride_cb_batch,
+    stride_cb_chunk,
+    stride_cb_head,
+    stride_cb_csize_m,
+    stride_cb_csize_k,
+    stride_x_batch,
+    stride_x_seqlen,
+    stride_x_head,
+    stride_x_hdim,
+    stride_z_batch,
+    stride_z_seqlen,
+    stride_z_head,
+    stride_z_hdim,
+    stride_out_batch,
+    stride_out_seqlen,
+    stride_out_head,
+    stride_out_hdim,
+    stride_dt_batch,
+    stride_dt_chunk,
+    stride_dt_head,
+    stride_dt_csize,
+    stride_dA_cs_batch,
+    stride_dA_cs_chunk,
+    stride_dA_cs_head,
+    stride_dA_cs_csize,
+    stride_seq_idx_batch,
+    stride_seq_idx_seqlen,
+    stride_C_batch,
+    stride_C_seqlen,
+    stride_C_head,
+    stride_C_dstate,
+    stride_states_batch,
+    stride_states_chunk,
+    stride_states_head,
+    stride_states_hdim,
+    stride_states_dstate,
     stride_D_head,
     # Meta-parameters
     IS_CAUSAL: tl.constexpr,
@@ -57,7 +177,9 @@ def _chunk_scan_fwd_kernel(
     D_HAS_HDIM: tl.constexpr,
     HAS_Z: tl.constexpr,
     HAS_SEQ_IDX: tl.constexpr,
-    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
     BLOCK_SIZE_DSTATE: tl.constexpr,
     IS_TRITON_22: tl.constexpr,
 ):
@@ -68,23 +190,31 @@ def _chunk_scan_fwd_kernel(
     num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
     pid_m = tl.program_id(axis=0) // num_pid_n
     pid_n = tl.program_id(axis=0) % num_pid_n
-    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
+    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (
+        pid_h // nheads_ngroups_ratio) * stride_cb_head
     x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
     dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
     dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
-    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
+    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (
+        pid_h // nheads_ngroups_ratio) * stride_C_head
     prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
     if HAS_SEQ_IDX:
         seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
 
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
+    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
+                      mask=offs_m < chunk_size,
+                      other=0.0).to(tl.float32)
 
     chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
     if HAS_SEQ_IDX:
-        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
-        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
+        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
+                               mask=pid_c >= 1,
+                               other=0)
+        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
+                            mask=offs_m < chunk_size_limit,
+                            other=-1)
     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
     # Without the if (pid_c > -1), with Triton 2.1.0, I get
@@ -92,23 +222,40 @@ def _chunk_scan_fwd_kernel(
     # With Triton 2.2.0, this works
     if IS_TRITON_22 or pid_c > -1:
         # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
-        offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
-        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)
-        prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)
+        offs_k_dstate = tl.arange(
+            0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
+        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen +
+                          offs_k_dstate[None, :] * stride_C_dstate)
+        prev_states_ptrs = prev_states_ptr + (
+            offs_n[None, :] * stride_states_hdim +
+            offs_k_dstate[:, None] * stride_states_dstate)
         if not HAS_SEQ_IDX:
             scale_m = tl.exp(dA_cs_m)
         else:
             scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
         if BLOCK_SIZE_DSTATE <= 128:
-            C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)
-            prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
+            C = tl.load(C_ptrs,
+                        mask=(offs_m[:, None] < chunk_size_limit) &
+                        (offs_k_dstate[None, :] < dstate),
+                        other=0.0)
+            prev_states = tl.load(prev_states_ptrs,
+                                  mask=(offs_k_dstate[:, None] < dstate) &
+                                  (offs_n[None, :] < hdim),
+                                  other=0.0)
             prev_states = prev_states.to(C_ptr.dtype.element_ty)
             acc = tl.dot(C, prev_states) * scale_m[:, None]
         else:
             for k in range(0, dstate, BLOCK_SIZE_K):
-                C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0)
+                C = tl.load(C_ptrs,
+                            mask=(offs_m[:, None] < chunk_size_limit) &
+                            (offs_k_dstate[None, :] < dstate - k),
+                            other=0.0)
                 # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
-                prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
+                prev_states = tl.load(
+                    prev_states_ptrs,
+                    mask=(offs_k_dstate[:, None] < dstate - k) &
+                    (offs_n[None, :] < hdim),
+                    other=0.0)
                 prev_states = prev_states.to(C_ptr.dtype.element_ty)
                 acc += tl.dot(C, prev_states)
                 C_ptrs += BLOCK_SIZE_K
@@ -116,24 +263,36 @@ def _chunk_scan_fwd_kernel(
             acc *= scale_m[:, None]
 
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
-    x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
+    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m +
+                        offs_k[None, :] * stride_cb_csize_k)
+    x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen +
+                      offs_n[None, :] * stride_x_hdim)
     dt_ptrs = dt_ptr + offs_k * stride_dt_csize
     dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
-    K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
+    K_MAX = chunk_size_limit if not IS_CAUSAL else min(
+        (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
     for k in range(0, K_MAX, BLOCK_SIZE_K):
-        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32)
-        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
+        cb = tl.load(cb_ptrs,
+                     mask=(offs_m[:, None] < chunk_size) &
+                     (offs_k[None, :] < chunk_size - k),
+                     other=0.0).to(tl.float32)
+        dA_cs_k = tl.load(dA_cumsum_ptrs,
+                          mask=offs_k < chunk_size - k,
+                          other=0.0).to(tl.float32)
         # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
         # So we don't need masking wrt seq_idx here.
-        cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
-        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
+        cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
+        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k,
+                       other=0.0).to(tl.float32)
         cb *= dt_k
         if IS_CAUSAL:
             mask = offs_m[:, None] >= k + offs_k[None, :]
             cb = tl.where(mask, cb, 0.0)
         cb = cb.to(x_ptr.dtype.element_ty)
-        x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0)
+        x = tl.load(x_ptrs,
+                    mask=(offs_k[:, None] < chunk_size_limit - k) &
+                    (offs_n[None, :] < hdim),
+                    other=0.0)
         acc += tl.dot(cb, x)
         cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
         x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
@@ -145,28 +304,54 @@ def _chunk_scan_fwd_kernel(
 
     if HAS_D:
         if D_HAS_HDIM:
-            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
+            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n,
+                        mask=offs_n < hdim,
+                        other=0.0).to(tl.float32)
         else:
             D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
-        x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
-                             mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
+        x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen +
+                                      offs_n[None, :] * stride_x_hdim),
+                             mask=(offs_m[:, None] < chunk_size_limit) &
+                             (offs_n[None, :] < hdim),
+                             other=0.0).to(tl.float32)
         acc += x_residual * D
 
     if HAS_Z:
         out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
-        out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :])
-        tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))
+        out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] +
+                                  offs_out_n[None, :])
+        tl.store(out_x_ptrs,
+                 acc,
+                 mask=(offs_out_m[:, None] < chunk_size_limit) &
+                 (offs_out_n[None, :] < hdim))
 
         z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
-        z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])
-        z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)
+        z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] +
+                          stride_z_hdim * offs_out_n[None, :])
+        z = tl.load(z_ptrs,
+                    mask=(offs_out_m[:, None] < chunk_size_limit) &
+                    (offs_out_n[None, :] < hdim),
+                    other=0.0).to(tl.float32)
         acc *= z * tl.sigmoid(z)
 
     out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
-    out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim)
-    tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))
+    out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] +
+                          offs_out_n[None, :] * stride_out_hdim)
+    tl.store(out_ptrs,
+             acc,
+             mask=(offs_out_m[:, None] < chunk_size_limit) &
+             (offs_out_n[None, :] < hdim))
+
 
-def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):
+def _chunk_scan_fwd(cb,
+                    x,
+                    dt,
+                    dA_cumsum,
+                    C,
+                    states,
+                    D=None,
+                    z=None,
+                    seq_idx=None):
     batch, seqlen, nheads, headdim = x.shape
     _, _, nchunks, chunk_size = dt.shape
     _, _, ngroups, dstate = C.shape
@@ -176,36 +361,88 @@ def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=Non
     if z is not None:
         assert z.shape == x.shape
     if D is not None:
-        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
+        assert D.shape == (nheads, headdim) or D.shape == (nheads, )
     assert dt.shape == (batch, nheads, nchunks, chunk_size)
     assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
     assert states.shape == (batch, nchunks, nheads, headdim, dstate)
     if seq_idx is not None:
         assert seq_idx.shape == (batch, seqlen)
     # Allocates output.
-    out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
+    out = torch.empty(batch,
+                      seqlen,
+                      nheads,
+                      headdim,
+                      device=x.device,
+                      dtype=x.dtype)
     if z is not None:
-        out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
+        out_x = torch.empty(batch,
+                            seqlen,
+                            nheads,
+                            headdim,
+                            device=x.device,
+                            dtype=x.dtype)
         assert out_x.stride() == out.stride()
     else:
         out_x = None
-    grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
-                    batch * nchunks, nheads)
-    z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))
-                  if z is not None else (0, 0, 0, 0))
+    grid = lambda META: (triton.cdiv(
+        chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
+            headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads)
+    z_strides = ((z.stride(0), z.stride(1), z.stride(2),
+                  z.stride(3)) if z is not None else (0, 0, 0, 0))
     _chunk_scan_fwd_kernel[grid](
-        cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,
-        chunk_size, headdim, dstate,
-        batch, seqlen, nheads // ngroups,
-        cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
-        x.stride(0), x.stride(1), x.stride(2), x.stride(3),
-        z_strides[0], z_strides[1], z_strides[2], z_strides[3],
-        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
-        dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
-        dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
-        *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
-        C.stride(0), C.stride(1), C.stride(2), C.stride(3),
-        states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
+        cb,
+        x,
+        z,
+        out,
+        out_x,
+        dt,
+        dA_cumsum,
+        seq_idx,
+        C,
+        states,
+        D,
+        chunk_size,
+        headdim,
+        dstate,
+        batch,
+        seqlen,
+        nheads // ngroups,
+        cb.stride(0),
+        cb.stride(1),
+        cb.stride(2),
+        cb.stride(3),
+        cb.stride(4),
+        x.stride(0),
+        x.stride(1),
+        x.stride(2),
+        x.stride(3),
+        z_strides[0],
+        z_strides[1],
+        z_strides[2],
+        z_strides[3],
+        out.stride(0),
+        out.stride(1),
+        out.stride(2),
+        out.stride(3),
+        dt.stride(0),
+        dt.stride(2),
+        dt.stride(1),
+        dt.stride(3),
+        dA_cumsum.stride(0),
+        dA_cumsum.stride(2),
+        dA_cumsum.stride(1),
+        dA_cumsum.stride(3),
+        *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else
+          (0, 0)),
+        C.stride(0),
+        C.stride(1),
+        C.stride(2),
+        C.stride(3),
+        states.stride(0),
+        states.stride(1),
+        states.stride(2),
+        states.stride(3),
+        states.stride(4),
         D.stride(0) if D is not None else 0,
         True,
         D is not None,
@@ -215,4 +452,4 @@ def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=Non
         HAS_SEQ_IDX=seq_idx is not None,
         IS_TRITON_22=TRITON_22,
     )
-    return out, out_x
\ No newline at end of file
+    return out, out_x
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
index 3184bbbf03d41..bafdcd2585e5a 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
@@ -1,22 +1,24 @@
 # Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_chunk_state.py
 
+# ruff: noqa: E501
 """We want triton==2.1.0 or 2.2.0 for this
 """
 
 import math
-import torch
-import torch.nn.functional as F
 
+import torch
 import triton
 import triton.language as tl
 
-from einops import rearrange, repeat
-
 from .softplus import softplus
 
 
 def init_to_zero(names):
-    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
+    return lambda nargs: [
+        nargs[name].zero_() for name in names if nargs[name] is not None
+    ]
+
 
 @triton.autotune(
     configs=[
@@ -33,20 +35,37 @@ def init_to_zero(names):
 @triton.jit
 def _chunk_cumsum_fwd_kernel(
     # Pointers to matrices
-    dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,
+    dt_ptr,
+    A_ptr,
+    dt_bias_ptr,
+    dt_out_ptr,
+    dA_cumsum_ptr,
     # Matrix dimension
-    batch, seqlen, nheads, chunk_size,
-    dt_min, dt_max,
+    batch,
+    seqlen,
+    nheads,
+    chunk_size,
+    dt_min,
+    dt_max,
     # Strides
-    stride_dt_batch, stride_dt_seqlen, stride_dt_head,
+    stride_dt_batch,
+    stride_dt_seqlen,
+    stride_dt_head,
     stride_A_head,
     stride_dt_bias_head,
-    stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,
-    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
+    stride_dt_out_batch,
+    stride_dt_out_chunk,
+    stride_dt_out_head,
+    stride_dt_out_csize,
+    stride_dA_cs_batch,
+    stride_dA_cs_chunk,
+    stride_dA_cs_head,
+    stride_dA_cs_csize,
     # Meta-parameters
     DT_SOFTPLUS: tl.constexpr,
     HAS_DT_BIAS: tl.constexpr,
-    BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
+    BLOCK_SIZE_H: tl.constexpr,
+    BLOCK_SIZE_CHUNK: tl.constexpr,
 ):
     pid_b = tl.program_id(axis=0)
 
@@ -60,60 +79,165 @@ def _chunk_cumsum_fwd_kernel(
 
     offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
     offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
-    dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
+    dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head +
+                        offs_c[None, :] * stride_dt_seqlen)
     A_ptrs = A_ptr + offs_h * stride_A_head
-    dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)
-    dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)
+    dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head +
+                                offs_c[None, :] * stride_dt_out_csize)
+    dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head +
+                                  offs_c[None, :] * stride_dA_cs_csize)
     chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
 
-    dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
+    dt = tl.load(dt_ptrs,
+                 mask=(offs_h[:, None] < nheads) &
+                 (offs_c[None, :] < chunk_size_limit),
+                 other=0.0).to(tl.float32)
     if HAS_DT_BIAS:
-        dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
+        dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head,
+                          mask=offs_h < nheads,
+                          other=0.0).to(tl.float32)
         dt += dt_bias[:, None]
     if DT_SOFTPLUS:
         dt = tl.where(dt <= 20.0, softplus(dt), dt)
     # As of Triton 2.2.0, tl.clamp is not available yet
     # dt = tl.clamp(dt, dt_min, dt_max)
     dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
-    dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
-    tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
+    dt = tl.where(
+        (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt,
+        0.0)
+    tl.store(dt_out_ptrs,
+             dt,
+             mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
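+    # accumulate dA = dt * A along the chunk dimension; downstream kernels
+    # use these cumulative sums as decay exponents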
     A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
     dA = dt * A[:, None]
     dA_cs = tl.cumsum(dA, axis=1)
-    tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
+    tl.store(dA_cs_ptrs,
+             dA_cs,
+             mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
 
 
 @triton.autotune(
     configs=[
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 256,
+                'BLOCK_SIZE_K': 64
+            },
+            num_stages=3,
+            num_warps=8),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 256,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=5,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 32,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=5,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=2),
     ],
     key=['hdim', 'dstate', 'chunk_size'],
 )
 @triton.jit
 def _chunk_state_fwd_kernel(
     # Pointers to matrices
-    x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
+    x_ptr,
+    b_ptr,
+    states_ptr,
+    dt_ptr,
+    dA_cumsum_ptr,
+    seq_idx_ptr,
     # Matrix dimensions
-    hdim, dstate, chunk_size,
-    batch, seqlen, nheads_ngroups_ratio,
+    hdim,
+    dstate,
+    chunk_size,
+    batch,
+    seqlen,
+    nheads_ngroups_ratio,
     # Strides
-    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
-    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
-    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
-    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
-    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
-    stride_seq_idx_batch, stride_seq_idx_seqlen,
+    stride_x_batch,
+    stride_x_seqlen,
+    stride_x_head,
+    stride_x_hdim,
+    stride_b_batch,
+    stride_b_seqlen,
+    stride_b_head,
+    stride_b_dstate,
+    stride_states_batch,
+    stride_states_chunk,
+    stride_states_head,
+    stride_states_hdim,
+    stride_states_dstate,
+    stride_dt_batch,
+    stride_dt_chunk,
+    stride_dt_head,
+    stride_dt_csize,
+    stride_dA_cs_batch,
+    stride_dA_cs_chunk,
+    stride_dA_cs_head,
+    stride_dA_cs_csize,
+    stride_seq_idx_batch,
+    stride_seq_idx_seqlen,
     # Meta-parameters
     HAS_SEQ_IDX: tl.constexpr,
-    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
 ):
     pid_bc = tl.program_id(axis=1).to(tl.int64)
     pid_c = pid_bc // batch
@@ -122,7 +246,8 @@ def _chunk_state_fwd_kernel(
     num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
     pid_m = tl.program_id(axis=0) // num_pid_n
     pid_n = tl.program_id(axis=0) % num_pid_n
-    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
+    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (
+        pid_h // nheads_ngroups_ratio) * stride_b_head
     x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
     dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
     dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
@@ -132,30 +257,46 @@ def _chunk_state_fwd_kernel(
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
-    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
+    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
+                      offs_k[None, :] * stride_x_seqlen)
+    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
+                      offs_k[:, None] * stride_b_seqlen)
     dt_ptrs = dt_ptr + offs_k * stride_dt_csize
-    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
+    dA_cs_last = tl.load(dA_cumsum_ptr +
+                         (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
     dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
     if HAS_SEQ_IDX:
         seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
 
     chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
     if HAS_SEQ_IDX:
-        seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
+        seq_idx_last = tl.load(seq_idx_ptr +
+                               (chunk_size_limit - 1) * stride_seq_idx_seqlen)
 
     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
-        x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0)
-        b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
-        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
+        x = tl.load(x_ptrs,
+                    mask=(offs_m[:, None] < hdim) &
+                    (offs_k[None, :] < chunk_size_limit - k),
+                    other=0.0)
+        b = tl.load(b_ptrs,
+                    mask=(offs_k[:, None] < chunk_size_limit - k) &
+                    (offs_n[None, :] < dstate),
+                    other=0.0).to(tl.float32)
+        dA_cs_k = tl.load(dA_cumsum_ptrs,
+                          mask=offs_k < chunk_size_limit - k,
+                          other=0.0).to(tl.float32)
         if HAS_SEQ_IDX:
-            seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
-        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
+            seq_idx_k = tl.load(seq_idx_ptrs,
+                                mask=offs_k < chunk_size_limit - k,
+                                other=-1)
+        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
+                       other=0.0).to(tl.float32)
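+        # weight each position by dt_k and by the decay remaining until the
+        # end of the chunk, exp(dA_cs_last - dA_cs_k)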
         if not HAS_SEQ_IDX:
-            scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
+            scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
         else:
-            scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
+            scale = tl.where(seq_idx_k == seq_idx_last,
+                             tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
         b *= scale[:, None]
         b = b.to(x_ptr.dtype.element_ty)
         acc += tl.dot(x, b)
@@ -170,40 +311,130 @@ def _chunk_state_fwd_kernel(
     states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
+    states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
+                                offs_n[None, :] * stride_states_dstate)
     c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
     tl.store(states_ptrs, states, mask=c_mask)
 
+
 @triton.autotune(
     configs=[
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 256,
+                'BLOCK_SIZE_K': 64
+            },
+            num_stages=3,
+            num_warps=8),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 256,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=5,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 32,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=5,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=2),
     ],
     key=['hdim', 'dstate', 'chunk_size'],
 )
 @triton.jit
 def _chunk_state_varlen_kernel(
     # Pointers to matrices
-    x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr,
+    x_ptr,
+    b_ptr,
+    dt_ptr,
+    dA_cumsum_ptr,
+    chunk_states_ptr,
+    cu_seqlens_ptr,
+    states_ptr,
     # Matrix dimensions
-    hdim, dstate, chunk_size,
-    seqlen, nheads_ngroups_ratio,
+    hdim,
+    dstate,
+    chunk_size,
+    seqlen,
+    nheads_ngroups_ratio,
     # Strides
-    stride_x_seqlen, stride_x_head, stride_x_hdim,
-    stride_b_seqlen, stride_b_head, stride_b_dstate,
-    stride_dt_chunk, stride_dt_head, stride_dt_csize,
-    stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
-    stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate,
-    stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate,
+    stride_x_seqlen,
+    stride_x_head,
+    stride_x_hdim,
+    stride_b_seqlen,
+    stride_b_head,
+    stride_b_dstate,
+    stride_dt_chunk,
+    stride_dt_head,
+    stride_dt_csize,
+    stride_dA_cs_chunk,
+    stride_dA_cs_head,
+    stride_dA_cs_csize,
+    stride_chunk_states_chunk,
+    stride_chunk_states_head,
+    stride_chunk_states_hdim,
+    stride_chunk_states_dstate,
+    stride_states_batch,
+    stride_states_head,
+    stride_states_hdim,
+    stride_states_dstate,
     # Meta-parameters
-    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
 ):
     pid_b = tl.program_id(axis=1)
     pid_h = tl.program_id(axis=2)
@@ -212,7 +443,8 @@ def _chunk_state_varlen_kernel(
     pid_n = tl.program_id(axis=0) % num_pid_n
     end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
     pid_c = (end_idx - 1) // chunk_size
-    b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
+    b_ptr += pid_c * chunk_size * stride_b_seqlen + (
+        pid_h // nheads_ngroups_ratio) * stride_b_head
     x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
     dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
     dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
@@ -221,10 +453,13 @@ def _chunk_state_varlen_kernel(
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
-    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
+    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
+                      offs_k[None, :] * stride_x_seqlen)
+    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
+                      offs_k[:, None] * stride_b_seqlen)
     dt_ptrs = dt_ptr + offs_k * stride_dt_csize
-    dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
+    dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) *
+                         stride_dA_cs_csize).to(tl.float32)
     dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
 
     chunk_size_limit = end_idx - pid_c * chunk_size
@@ -233,12 +468,24 @@ def _chunk_state_varlen_kernel(
 
     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
-        x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0)
-        b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
-        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
-        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
-        scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
-                         tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
+        x = tl.load(x_ptrs,
+                    mask=(offs_m[:, None] < hdim) &
+                    (offs_k[None, :] < chunk_size_limit - k) &
+                    (offs_k[None, :] >= start_idx_cur - k),
+                    other=0.0)
+        b = tl.load(b_ptrs,
+                    mask=(offs_k[:, None] < chunk_size_limit - k) &
+                    (offs_n[None, :] < dstate) &
+                    (offs_k[:, None] >= start_idx_cur - k),
+                    other=0.0).to(tl.float32)
+        dA_cs_k = tl.load(dA_cumsum_ptrs,
+                          mask=offs_k < chunk_size_limit - k,
+                          other=0.0).to(tl.float32)
+        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
+                       other=0.0).to(tl.float32)
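+        # positions before this sequence's start (start_idx_cur) are masked
+        # out so they do not contribute to its state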
+        scale = tl.where(
+            (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
+            tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
         b *= scale[:, None]
         b = b.to(x_ptr.dtype.element_ty)
         acc += tl.dot(x, b)
@@ -249,8 +496,13 @@ def _chunk_state_varlen_kernel(
 
     # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
     if start_idx < pid_c * chunk_size:
-        chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate)
-        chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
+        chunk_states_ptrs = chunk_states_ptr + (
+            offs_m[:, None] * stride_chunk_states_hdim +
+            offs_n[None, :] * stride_chunk_states_dstate)
+        chunk_states = tl.load(chunk_states_ptrs,
+                               mask=(offs_m[:, None] < hdim) &
+                               (offs_n[None, :] < dstate),
+                               other=0.0).to(tl.float32)
         # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
         scale = tl.exp(dA_cs_last)
         acc += chunk_states * scale
@@ -260,37 +512,77 @@ def _chunk_state_varlen_kernel(
     states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
+    states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
+                                offs_n[None, :] * stride_states_dstate)
     c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
     tl.store(states_ptrs, states, mask=c_mask)
 
 
-def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
+def _chunk_cumsum_fwd(dt,
+                      A,
+                      chunk_size,
+                      dt_bias=None,
+                      dt_softplus=False,
+                      dt_limit=(0.0, float("inf"))):
     batch, seqlen, nheads = dt.shape
-    assert A.shape == (nheads,)
+    assert A.shape == (nheads, )
     if dt_bias is not None:
-        assert dt_bias.shape == (nheads,)
+        assert dt_bias.shape == (nheads, )
     nchunks = math.ceil(seqlen / chunk_size)
-    dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
-    dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
-    grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
+    dt_out = torch.empty(batch,
+                         nheads,
+                         nchunks,
+                         chunk_size,
+                         device=dt.device,
+                         dtype=torch.float32)
+    dA_cumsum = torch.empty(batch,
+                            nheads,
+                            nchunks,
+                            chunk_size,
+                            device=dt.device,
+                            dtype=torch.float32)
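+    # one kernel instance per (batch, chunk, block of heads)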
+    grid_chunk_cs = lambda META: (batch, nchunks,
+                                  triton.cdiv(nheads, META['BLOCK_SIZE_H']))
     with torch.cuda.device(dt.device.index):
         _chunk_cumsum_fwd_kernel[grid_chunk_cs](
-            dt, A, dt_bias, dt_out, dA_cumsum,
-            batch, seqlen, nheads, chunk_size,
-            dt_limit[0], dt_limit[1],
-            dt.stride(0), dt.stride(1), dt.stride(2),
+            dt,
+            A,
+            dt_bias,
+            dt_out,
+            dA_cumsum,
+            batch,
+            seqlen,
+            nheads,
+            chunk_size,
+            dt_limit[0],
+            dt_limit[1],
+            dt.stride(0),
+            dt.stride(1),
+            dt.stride(2),
             A.stride(0),
             dt_bias.stride(0) if dt_bias is not None else 0,
-            dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),
-            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
+            dt_out.stride(0),
+            dt_out.stride(2),
+            dt_out.stride(1),
+            dt_out.stride(3),
+            dA_cumsum.stride(0),
+            dA_cumsum.stride(2),
+            dA_cumsum.stride(1),
+            dA_cumsum.stride(3),
             dt_softplus,
             HAS_DT_BIAS=dt_bias is not None,
             BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
         )
     return dA_cumsum, dt_out
 
-def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True):
+
+def _chunk_state_fwd(B,
+                     x,
+                     dt,
+                     dA_cumsum,
+                     seq_idx=None,
+                     states=None,
+                     states_in_fp32=True):
     batch, seqlen, nheads, headdim = x.shape
     _, _, nchunks, chunk_size = dt.shape
     _, _, ngroups, dstate = B.shape
@@ -304,24 +596,54 @@ def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_f
         assert states.shape == (batch, nchunks, nheads, headdim, dstate)
     else:
         states_dtype = torch.float32 if states_in_fp32 else B.dtype
-        states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype)
-    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
-                    batch * nchunks, nheads)
+        states = torch.empty((batch, nchunks, nheads, headdim, dstate),
+                             device=x.device,
+                             dtype=states_dtype)
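+    # grid: (headdim tiles * dstate tiles, batch * nchunks, nheads)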
+    grid = lambda META: (
+        triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(
+            dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads)
     with torch.cuda.device(x.device.index):
         _chunk_state_fwd_kernel[grid](
-            x, B, states, dt, dA_cumsum, seq_idx,
-            headdim, dstate, chunk_size,
-            batch, seqlen, nheads // ngroups,
-            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
-            B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
-            states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
-            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
-            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
-            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
+            x,
+            B,
+            states,
+            dt,
+            dA_cumsum,
+            seq_idx,
+            headdim,
+            dstate,
+            chunk_size,
+            batch,
+            seqlen,
+            nheads // ngroups,
+            x.stride(0),
+            x.stride(1),
+            x.stride(2),
+            x.stride(3),
+            B.stride(0),
+            B.stride(1),
+            B.stride(2),
+            B.stride(-1),
+            states.stride(0),
+            states.stride(1),
+            states.stride(2),
+            states.stride(3),
+            states.stride(4),
+            dt.stride(0),
+            dt.stride(2),
+            dt.stride(1),
+            dt.stride(3),
+            dA_cumsum.stride(0),
+            dA_cumsum.stride(2),
+            dA_cumsum.stride(1),
+            dA_cumsum.stride(3),
+            *((seq_idx.stride(0),
+               seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
             HAS_SEQ_IDX=seq_idx is not None,
         )
     return states
 
+
 def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
     total_seqlen, nheads, headdim = x.shape
     _, nchunks, chunk_size = dt.shape
@@ -333,19 +655,47 @@ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
     assert dt.shape == (nheads, nchunks, chunk_size)
     assert dA_cumsum.shape == dt.shape
     assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
-    states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device)
-    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
-                    batch, nheads)
+    states = torch.empty(batch,
+                         nheads,
+                         headdim,
+                         dstate,
+                         dtype=chunk_states.dtype,
+                         device=chunk_states.device)
+    grid = lambda META: (
+        triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(
+            dstate, META['BLOCK_SIZE_N']), batch, nheads)
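+    # grid: (headdim tiles * dstate tiles, batch, nheads)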
     with torch.cuda.device(x.device.index):
         _chunk_state_varlen_kernel[grid](
-            x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states,
-            headdim, dstate, chunk_size,
-            total_seqlen, nheads // ngroups,
-            x.stride(0), x.stride(1), x.stride(2),
-            B.stride(0), B.stride(1), B.stride(2),
-            dt.stride(1), dt.stride(0), dt.stride(2),
-            dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2),
-            chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3),
-            states.stride(0), states.stride(1), states.stride(2), states.stride(3),
+            x,
+            B,
+            dt,
+            dA_cumsum,
+            chunk_states,
+            cu_seqlens,
+            states,
+            headdim,
+            dstate,
+            chunk_size,
+            total_seqlen,
+            nheads // ngroups,
+            x.stride(0),
+            x.stride(1),
+            x.stride(2),
+            B.stride(0),
+            B.stride(1),
+            B.stride(2),
+            dt.stride(1),
+            dt.stride(0),
+            dt.stride(2),
+            dA_cumsum.stride(1),
+            dA_cumsum.stride(0),
+            dA_cumsum.stride(2),
+            chunk_states.stride(0),
+            chunk_states.stride(1),
+            chunk_states.stride(2),
+            chunk_states.stride(3),
+            states.stride(0),
+            states.stride(1),
+            states.stride(2),
+            states.stride(3),
         )
     return states
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
index 728024a6b31fa..90854fd0c0a10 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -1,50 +1,67 @@
 # Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_combined.py
 
+# ruff: noqa: E501
 """We want triton==2.1.0 or 2.2.0 for this
 """
 
-from packaging import version
-
 import torch
-
 import triton
-
 from einops import rearrange
+from packaging import version
 
 from .ssd_bmm import _bmm_chunk_fwd
-from .ssd_chunk_state import _chunk_cumsum_fwd
-from .ssd_chunk_state import _chunk_state_fwd
-from .ssd_chunk_state import chunk_state_varlen
-from .ssd_state_passing import _state_passing_fwd
 from .ssd_chunk_scan import _chunk_scan_fwd
+from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
+                              chunk_state_varlen)
+from .ssd_state_passing import _state_passing_fwd
 
 TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
 
+
 def init_to_zero(names):
-    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
+    return lambda nargs: [
+        nargs[name].zero_() for name in names if nargs[name] is not None
+    ]
 
-def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
+
+def _mamba_chunk_scan_combined_fwd(x,
+                                   dt,
+                                   A,
+                                   B,
+                                   C,
+                                   chunk_size,
+                                   D=None,
+                                   z=None,
+                                   dt_bias=None,
+                                   initial_states=None,
+                                   seq_idx=None,
+                                   cu_seqlens=None,
+                                   dt_softplus=False,
+                                   dt_limit=(0.0, float("inf"))):
     batch, seqlen, nheads, headdim = x.shape
     _, _, ngroups, dstate = B.shape
     assert nheads % ngroups == 0
     assert B.shape == (batch, seqlen, ngroups, dstate)
     assert x.shape == (batch, seqlen, nheads, headdim)
     assert dt.shape == (batch, seqlen, nheads)
-    assert A.shape == (nheads,)
+    assert A.shape == (nheads, )
     assert C.shape == B.shape
     if z is not None:
         assert z.shape == x.shape
     if D is not None:
-        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
+        assert D.shape == (nheads, headdim) or D.shape == (nheads, )
     if seq_idx is not None:
         assert seq_idx.shape == (batch, seqlen)
     if B.stride(-1) != 1:
         B = B.contiguous()
     if C.stride(-1) != 1:
         C = C.contiguous()
-    if x.stride(-1) != 1 and x.stride(1) != 1:  # Either M or K dimension should be contiguous
+    # Either M or K dimension should be contiguous
+    if x.stride(-1) != 1 and x.stride(1) != 1:
         x = x.contiguous()
-    if z is not None and z.stride(-1) != 1 and z.stride(1) != 1:  # Either M or K dimension should be contiguous
+    # Either M or K dimension should be contiguous
+    if z is not None and z.stride(-1) != 1 and z.stride(1) != 1:
         z = z.contiguous()
     if D is not None and D.stride(-1) != 1:
         D = D.contiguous()
@@ -54,28 +71,73 @@ def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, d
     # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
     # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
     # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
-    dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
-    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
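+    # combined forward pass: per-chunk cumsum of dt/dA, per-chunk states,
+    # state passing across chunks, the per-chunk CB matmul, and the final
+    # intra-chunk scan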
+    dA_cumsum, dt = _chunk_cumsum_fwd(dt,
+                                      A,
+                                      chunk_size,
+                                      dt_bias=dt_bias,
+                                      dt_softplus=dt_softplus,
+                                      dt_limit=dt_limit)
+    states = _chunk_state_fwd(B,
+                              x,
+                              dt,
+                              dA_cumsum,
+                              seq_idx=seq_idx,
+                              states_in_fp32=True)
     # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
     # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
     # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
-    states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
-                                              initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
-                                              seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)
-    states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]]
+    states, final_states = _state_passing_fwd(
+        rearrange(states, "... p n -> ... (p n)"),
+        dA_cumsum[:, :, :, -1],
+        initial_states=rearrange(initial_states, "... p n -> ... (p n)")
+        if initial_states is not None else None,
+        seq_idx=seq_idx,
+        chunk_size=chunk_size,
+        out_dtype=C.dtype)
+    states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
+                            for t in [states, final_states])
     # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
     # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
-    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
-    out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)
+    CB = _bmm_chunk_fwd(C,
+                        B,
+                        chunk_size,
+                        seq_idx=seq_idx,
+                        output_dtype=torch.float32)
+    out, out_x = _chunk_scan_fwd(CB,
+                                 x,
+                                 dt,
+                                 dA_cumsum,
+                                 C,
+                                 states,
+                                 D=D,
+                                 z=z,
+                                 seq_idx=seq_idx)
     if cu_seqlens is None:
         return out, out_x, dt, dA_cumsum, states, final_states
     else:
         assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
-        varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0),
+        varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0),
+                                           dt.squeeze(0), dA_cumsum.squeeze(0),
                                            cu_seqlens, states.squeeze(0))
         return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
 
-def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False):
+
+def mamba_chunk_scan_combined(x,
+                              dt,
+                              A,
+                              B,
+                              C,
+                              chunk_size,
+                              D=None,
+                              z=None,
+                              dt_bias=None,
+                              initial_states=None,
+                              seq_idx=None,
+                              cu_seqlens=None,
+                              dt_softplus=False,
+                              dt_limit=(0.0, float("inf")),
+                              return_final_states=False,
+                              return_varlen_states=False):
     """
     Argument:
         x: (batch, seqlen, nheads, headdim)
@@ -99,9 +161,26 @@ def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bia
         cu_seqlens = None
     else:
         assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
-    out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
+    out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
+        x,
+        dt,
+        A,
+        B,
+        C,
+        chunk_size,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        initial_states=initial_states,
+        seq_idx=seq_idx,
+        cu_seqlens=cu_seqlens,
+        dt_softplus=dt_softplus,
+        dt_limit=dt_limit)
     if not return_varlen_states:
         return out if not return_final_states else (out, final_states)
     else:
         varlen_states = rest[0]
-        return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states)
\ No newline at end of file
+        if not return_final_states:
+            return (out, varlen_states)
+        return (out, final_states, varlen_states)
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
index 59ed1d17cfda2..dfc87fc7e5c68 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
@@ -1,10 +1,11 @@
 # Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_state_passing.py
 
+# ruff: noqa: E501
 """We want triton==2.1.0 or 2.2.0 for this
 """
 
 import torch
-
 import triton
 import triton.language as tl
 
@@ -23,16 +24,37 @@
 @triton.jit
 def _state_passing_fwd_kernel(
     # Pointers to matrices
-    states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,
+    states_ptr,
+    out_ptr,
+    final_states_ptr,
+    dA_cs_ptr,
+    initstates_ptr,
+    seq_idx_ptr,
     # Matrix dimensions
-    dim, nchunks, seqlen, chunk_size,
+    dim,
+    nchunks,
+    seqlen,
+    chunk_size,
     # Strides
-    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,
-    stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
-    stride_final_states_batch, stride_final_states_head, stride_final_states_dim,
-    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
-    stride_initstates_batch, stride_initstates_head, stride_initstates_dim,
-    stride_seq_idx_batch, stride_seq_idx_seqlen,
+    stride_states_batch,
+    stride_states_chunk,
+    stride_states_head,
+    stride_states_dim,
+    stride_out_batch,
+    stride_out_chunk,
+    stride_out_head,
+    stride_out_dim,
+    stride_final_states_batch,
+    stride_final_states_head,
+    stride_final_states_dim,
+    stride_dA_cs_batch,
+    stride_dA_cs_chunk,
+    stride_dA_cs_head,
+    stride_initstates_batch,
+    stride_initstates_head,
+    stride_initstates_dim,
+    stride_seq_idx_batch,
+    stride_seq_idx_seqlen,
     # Meta-parameters
     HAS_INITSTATES: tl.constexpr,
     HAS_SEQ_IDX: tl.constexpr,
@@ -59,16 +81,20 @@ def _state_passing_fwd_kernel(
         states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
     else:
         initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
-        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        states = tl.load(initstates_ptrs, mask=offs_m < dim,
+                         other=0.0).to(tl.float32)
     tl.store(out_ptrs, states, mask=offs_m < dim)
     out_ptrs += stride_out_chunk
     seq_idx = 0
     for c in range(nchunks):
-        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        new_states = tl.load(states_ptrs, mask=offs_m < dim,
+                             other=0.0).to(tl.float32)
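+        # recurrence across chunks: states = exp(dA_cs) * states + new_states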
         dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
         scale = tl.exp(dA_cs)
         if HAS_SEQ_IDX:
-            seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
+            seq_idx_new = tl.load(seq_idx_ptr +
+                                  (min((c + 1) * chunk_size, seqlen) - 1) *
+                                  stride_seq_idx_seqlen)
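+            # zero the carried scale at sequence boundaries so state does
+            # not leak across packed sequences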
             scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
             seq_idx = seq_idx_new
         states = scale * states + new_states
@@ -81,7 +107,11 @@ def _state_passing_fwd_kernel(
         out_ptrs += stride_out_chunk
 
 
-def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,
+def _state_passing_fwd(states,
+                       dA_chunk_cumsum,
+                       initial_states=None,
+                       seq_idx=None,
+                       chunk_size=None,
                        out_dtype=None):
     batch, nchunks, nheads, dim = states.shape
     assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
@@ -92,20 +122,44 @@ def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=Non
         seqlen = seq_idx.shape[-1]
         assert seq_idx.shape == (batch, seqlen)
     out_dtype = states.dtype if out_dtype is None else out_dtype
-    out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
-    final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)
+    out = torch.empty((batch, nchunks, nheads, dim),
+                      device=states.device,
+                      dtype=out_dtype)
+    final_states = torch.empty((batch, nheads, dim),
+                               device=states.device,
+                               dtype=torch.float32)
     grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
     with torch.cuda.device(states.device.index):
         _state_passing_fwd_kernel[grid](
-            states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,
-            dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
-            states.stride(0), states.stride(1), states.stride(2), states.stride(3),
-            out.stride(0), out.stride(1), out.stride(2), out.stride(3),
-            final_states.stride(0), final_states.stride(1), final_states.stride(2),
-            dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
-            *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
-              if initial_states is not None else (0, 0, 0)),
-            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
+            states,
+            out,
+            final_states,
+            dA_chunk_cumsum,
+            initial_states,
+            seq_idx,
+            dim,
+            nchunks,
+            seqlen if seq_idx is not None else 0,
+            chunk_size if seq_idx is not None else 0,
+            states.stride(0),
+            states.stride(1),
+            states.stride(2),
+            states.stride(3),
+            out.stride(0),
+            out.stride(1),
+            out.stride(2),
+            out.stride(3),
+            final_states.stride(0),
+            final_states.stride(1),
+            final_states.stride(2),
+            dA_chunk_cumsum.stride(0),
+            dA_chunk_cumsum.stride(2),
+            dA_chunk_cumsum.stride(1),
+            *((initial_states.stride(0), initial_states.stride(1),
+               initial_states.stride(2)) if initial_states is not None else
+              (0, 0, 0)),
+            *((seq_idx.stride(0),
+               seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
             HAS_INITSTATES=initial_states is not None,
             HAS_SEQ_IDX=seq_idx is not None,
         )
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index 5c6a8ab043170..2693c45b27520 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -10,16 +10,16 @@
 from vllm.attention.layer import Attention
 from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (QKVParallelLinear,
-                                               MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -67,6 +67,7 @@ def forward(self, x):
         x, _ = self.down_proj(x)
         return x
 
+
 class BambaMixerDecoderLayer(nn.Module):
 
     def __init__(self,
@@ -161,7 +162,7 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             base=rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(), # see impl of get_rope
+            dtype=torch.get_default_dtype(),  # see impl of get_rope
         )
 
         self.qkv_proj = QKVParallelLinear(
@@ -203,23 +204,28 @@ def self_attention(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
-        # because the bamba model may potentially handle long sequences, 
-        # we should adjust the sin_cos cache if necesary to avoid out of bounds
+        # because the bamba model may potentially handle long sequences,
+        # we should adjust the sin_cos cache if necessary to avoid an
+        # out-of-bounds access
         # - first get the max_position
         max_position = max(
             getattr(attn_metadata, 'max_prefill_seq_len', 0),
             getattr(attn_metadata, 'max_decode_seq_len', 0),
         )
         if max_position == 0:
-            # if we cannot get the max lenght from the metadata, then
-            # get it frmo the positions
+            # if we cannot get the max length from the metadata, then
+            # get it from the positions
             max_position = positions.max().item()
 
-        if self.rotary_emb.max_position_embeddings <= max_position:
+        # Setting VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 can allow inputs longer
+        # than max_position_embeddings. We extend the rope cache to prevent
+        # CUDA errors. Be aware that the outputs could be of lower quality
+        # for long sequence lengths.
+        rotary = self.rotary_emb
+        if rotary.max_position_embeddings <= max_position:
             # we set it to the next power of two that covers it
-            while self.rotary_emb.max_position_embeddings <= max_position:
-                self.rotary_emb.max_position_embeddings *= 2
-            self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache()
+            while rotary.max_position_embeddings <= max_position:
+                rotary.max_position_embeddings *= 2
+            rotary.cos_sin_cache = rotary._compute_cos_sin_cache()
 
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
@@ -260,6 +266,7 @@ def forward(
     "mamba": BambaMixerDecoderLayer
 }
 
+
 class BambaModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -312,10 +319,11 @@ def forward(
         # add additional attn_metadata for the mixer layers
         if attn_metadata.num_prefills > 0:
             sed_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(zip(
-                attn_metadata.query_start_loc,
-                attn_metadata.query_start_loc[1:],
-            )):
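+            # label each prefill token with the index of its sequence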
+            for i, (srt, end) in enumerate(
+                    zip(
+                        attn_metadata.query_start_loc,
+                        attn_metadata.query_start_loc[1:],
+                    )):
                 sed_idx[srt:end] = i
 
             attn_metadata.seq_idx = sed_idx
@@ -335,7 +343,8 @@ def forward(
 
             layer_mamba_cache_params = None
             if isinstance(layer, BambaMixerDecoderLayer):
-                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i - num_attn)
+                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
+                    i - num_attn)
 
             hidden_states, residual = layer(
                 positions=positions,
@@ -457,18 +466,14 @@ def _get_mamba_cache_shape(
 
         intermediate_size = self.config.mamba_expand * hidden_size
 
-        # if n_groups is not divisible by world_size, need to extend the shards to ensure
-        # all groups needed by a head is sharded along with it
-        n_groups = (
-            self.config.mamba_n_groups + 
-            extra_groups_for_head_shards(self.config.mamba_n_groups, world_size)
-        )
+        # if n_groups is not divisible by world_size, we need to extend the
+        # shards to ensure all groups needed by a head are sharded with it
+        n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards(
+            self.config.mamba_n_groups, world_size))
 
         # - heads and n_groups are TP-ed
-        conv_dim = (
-            intermediate_size + 
-            2 * n_groups * self.config.mamba_d_state
-        )
+        conv_dim = (intermediate_size +
+                    2 * n_groups * self.config.mamba_d_state)
         conv_state_shape = (
             divide(conv_dim, world_size),
             self.config.mamba_d_conv - 1,