From b69b957e6ba2910ca673c810d8b47ae080feccf6 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Fri, 19 Jul 2024 20:46:34 +0800 Subject: [PATCH] migrate to tridao's native varlen causal_conv1d kernel for speedup --- mamba_ssm/models/mixer_seq_simple.py | 12 +- mamba_ssm/modules/block.py | 2 +- mamba_ssm/modules/mamba_simple.py | 105 +++---------- mamba_ssm/ops/selective_scan_interface.py | 147 +++++++++--------- .../ops/test_mamba_cu_seqlens_equivalence.py | 28 +++- 5 files changed, 122 insertions(+), 172 deletions(-) diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index 06ecb183..fae2257a 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -16,7 +16,6 @@ from mamba_ssm.modules.mamba2 import Mamba2 from mamba_ssm.modules.mha import MHA from mamba_ssm.modules.mlp import GatedMLP -from mamba_ssm.modules.mamba_simple import Block as Block_Mamba1 from mamba_ssm.modules.block import Block from mamba_ssm.utils.generation import GenerationMixin from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf @@ -71,8 +70,7 @@ def create_block( mlp_cls = partial( GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs ) - block_cls = Block if ssm_layer == "Mamba2" else Block_Mamba1 - block = block_cls( + block = Block( d_model, mixer_cls, mlp_cls, @@ -189,12 +187,12 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) for i, layer in enumerate(self.layers) } - def forward(self, input_ids, cu_seqlens=None, inference_params=None, **mixer_kwargs): + def forward(self, input_ids, inference_params=None, **mixer_kwargs): hidden_states = self.embedding(input_ids) residual = None for layer in self.layers: hidden_states, residual = layer( - hidden_states, residual, cu_seqlens=cu_seqlens, inference_params=inference_params, **mixer_kwargs + hidden_states, residual, inference_params=inference_params, **mixer_kwargs ) if not self.fused_add_norm: residual = (hidden_states + residual) if residual is not None else hidden_states @@ -273,12 +271,12 @@ def tie_weights(self): def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, cu_seqlens=None, **mixer_kwargs): + def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs): """ "position_ids" is just to be compatible with Transformer generation. We don't use it. 
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
-        hidden_states = self.backbone(input_ids, cu_seqlens=cu_seqlens, inference_params=inference_params, **mixer_kwargs)
+        hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         lm_logits = self.lm_head(hidden_states)
diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py
index 5fdaff05..b0ed44e1 100644
--- a/mamba_ssm/modules/block.py
+++ b/mamba_ssm/modules/block.py
@@ -88,4 +88,4 @@ def forward(
         return hidden_states, residual
 
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
\ No newline at end of file
+        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py
index 7bf976c4..d27ec31d 100644
--- a/mamba_ssm/modules/mamba_simple.py
+++ b/mamba_ssm/modules/mamba_simple.py
@@ -10,7 +10,7 @@
 
 from einops import rearrange, repeat
 
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, selective_scan_ref
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 
 try:
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -119,10 +119,12 @@ def __init__(
     def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
         """
         hidden_states: (B, L, D)
-        cu_seqlens: one-dimensional tensor representing cumulative start indexes of packed sequence, a.k.a., B=1
+        cu_seqlens: (optional) 1-D tensor of cumulative sequence lengths of the packed input, starting at 0, ending at L, and sorted in ascending order.
         Returns: same shape as hidden_states
         """
         batch, seqlen, dim = hidden_states.shape
+        if cu_seqlens is not None:
+            assert batch == 1 and cu_seqlens.ndimension() == 1, "varlen mamba1 is only supported with B=1"
 
         conv_state, ssm_state = None, None
         if inference_params is not None:
@@ -158,46 +160,40 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
                 self.D.float(),
                 delta_bias=self.dt_proj.bias.float(),
                 delta_softplus=True,
-                cu_seqlens=cu_seqlens,
-                d_conv=torch.tensor(self.d_conv)
+                cu_seqlens=cu_seqlens
             )
         else:
             x, z = xz.chunk(2, dim=1)
-
-            # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences
-            if cu_seqlens is not None:
-                padded_x = x
-                count = 0
-                for idx in cu_seqlens[1:-1].tolist():
-                    padded_idx = idx + count*(self.d_conv - 1)
-                    padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2)
-                    count = count + 1
-                x = padded_x
-
             # Compute short convolution
             if conv_state is not None:
                 # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                 # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) + if cu_seqlens is not None: + # naive pure python implementation of varlen causal_conv1d + for i, s in enumerate(cu_seqlens[1:-1]): + x = torch.cat((x[..., :s + i*(self.d_conv - 1)], torch.zeros_like(x[..., :(self.d_conv - 1)]), x[..., s + i*(self.d_conv - 1):]), dim=2) + mask = torch.cat([torch.cat((torch.full((s,), True, dtype=torch.bool, device=x.device), + torch.full((self.d_conv - 1,), False, dtype=torch.bool, device=x.device)), dim=0) + for s in (cu_seqlens[1:] - cu_seqlens[:-1])], dim=0) + x = self.act(self.conv1d(x)[:, :, mask]) + else: + x = self.act(self.conv1d(x)[..., :seqlen]) else: assert self.activation in ["silu", "swish"] + if cu_seqlens is not None: + seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) + else: + seq_idx = None x = causal_conv1d_fn( - x=x, + x=x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, + seq_idx=seq_idx, activation=self.activation, ) - - # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences - if cu_seqlens is not None: - mask = [] - for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): - mask.extend([True] * seq_len) - mask.extend([False] * (self.d_conv - 1)) - mask = mask[:-(self.d_conv - 1)] - x = x[:, :, mask] # We're careful here about the layout, to avoid extra transposes. # We want dt to have d as the slowest moving dimension @@ -208,7 +204,6 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - assert self.activation in ["silu", "swish"] y = selective_scan_fn( x, @@ -317,59 +312,3 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states conv_state.zero_() ssm_state.zero_() return conv_state, ssm_state - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, cu_seqlens=None, inference_params=None - ): - r"""Pass the input through the encoder layer. - Args: - hidden_states: the sequence to the encoder layer (required). 
- residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, cu_seqlens=cu_seqlens, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index d641380e..fa43b38b 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -169,12 +169,19 @@ class MambaInnerFn(torch.autograd.Function): def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None, checkpoint_lvl=1): + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." assert checkpoint_lvl in [0, 1] + + if cu_seqlens is not None: + seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) + else: + seq_idx = None + L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) @@ -188,28 +195,18 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) - - # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences - if cu_seqlens is not None: - padded_x = x - count = 0 - for idx in cu_seqlens[1:-1].tolist(): - padded_idx = idx + count*(d_conv - 1) - padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) - count = count + 1 - x = padded_x - - conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) - - # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences - if cu_seqlens is not None: - mask = [] - for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): - mask.extend([True] * seq_len) - mask.extend([False] * (d_conv - 1)) - mask = mask[:-(d_conv - 1)] - conv1d_out = conv1d_out[:, :, mask] + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + conv1d_weight, + conv1d_bias, + seq_idx, + None, + None, + True + ) + if 
conv1d_out.stride(-1) != 1: + conv1d_out = conv1d_out.contiguous() # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension @@ -248,7 +245,16 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh if D is not None: D = D.contiguous() out, scan_intermediates, out_z = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seqlens + conv1d_out, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + cu_seqlens ) ctx.delta_softplus = delta_softplus ctx.out_proj_bias_is_None = out_proj_bias is None @@ -257,7 +263,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) + A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, seq_idx) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @@ -266,40 +272,27 @@ def backward(ctx, dout): # dout: (batch, seqlen, dim) assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, - conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) = ctx.saved_tensors + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, seq_idx) = ctx.saved_tensors L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) if dout.stride(-1) != 1: dout = dout.contiguous() - - x_bak = x if ctx.checkpoint_lvl == 1: - # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences - if cu_seqlens is not None: - padded_x = x - count = 0 - for idx in cu_seqlens[1:-1].tolist(): - padded_idx = idx + count*(d_conv - 1) - padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) - count = count + 1 - x = padded_x - - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) - - # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences - if cu_seqlens is not None: - mask = [] - for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): - mask.extend([True] * seq_len) - mask.extend([False] * (d_conv - 1)) - mask = mask[:-(d_conv - 1)] - conv1d_out = conv1d_out[:, :, mask] - + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + conv1d_weight, + conv1d_bias, + seq_idx, + None, + None, + True + ) + if conv1d_out.stride(-1) != 1: + conv1d_out = conv1d_out.contiguous() delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) - x = x_bak # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). @@ -345,33 +338,42 @@ def backward(ctx, dout): # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). 
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( - x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + conv1d_weight, + conv1d_bias, + dconv1d_out, + seq_idx, + None, + None, + dx.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else dx, + False, + True ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") - return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, + return (torch.cat((dx, dz), dim=1) if cu_seqlens is not None else dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None, None, None) + dB_proj_bias, dC_proj_bias, None, None) def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None + C_proj_bias=None, delta_softplus=True, cu_seqlens=None ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens, d_conv) + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens) def mamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None + C_proj_bias=None, delta_softplus=True, cu_seqlens=None ): assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." L = xz.shape[-1] @@ -379,26 +381,19 @@ def mamba_inner_ref( d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) - # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences if cu_seqlens is not None: - padded_x = x - count = 0 - for idx in cu_seqlens[1:-1].tolist(): - padded_idx = idx + count*(d_conv - 1) - padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) - count = count + 1 - x = padded_x - - x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu") - - # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences - if cu_seqlens is not None: - mask = [] - for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): - mask.extend([True] * seq_len) - mask.extend([False] * (d_conv - 1)) - mask = mask[:-(d_conv - 1)] - x = x[:, :, mask] + seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) + else: + seq_idx = None + + x = causal_conv1d_fn( + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + rearrange(conv1d_weight, "d 1 w -> d w"), + conv1d_bias, + seq_idx=seq_idx, + activation="silu" + ) # We're being very careful here about the layout, to avoid extra transposes. 
    # We want delta to have d as the slowest moving dimension
@@ -422,5 +417,5 @@ def mamba_inner_ref(
         C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
     else:
         C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
-    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
+    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True, cu_seqlens=cu_seqlens)
     return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
\ No newline at end of file
diff --git a/tests/ops/test_mamba_cu_seqlens_equivalence.py b/tests/ops/test_mamba_cu_seqlens_equivalence.py
index 01d066ef..1937421f 100644
--- a/tests/ops/test_mamba_cu_seqlens_equivalence.py
+++ b/tests/ops/test_mamba_cu_seqlens_equivalence.py
@@ -1,3 +1,4 @@
+import copy
 import random
 
 import torch
@@ -82,7 +83,7 @@ def main():
     assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1]
 
     # creat one simple mamba block
-    mamba = Mamba(
+    mamba_ref = Mamba(
         # This module uses roughly 3 * expand * d_model^2 parameters
         d_model=hidden_dim, # Model dimension d_model
         d_state=16,  # SSM state expansion factor
         d_conv=4,    # Local convolution width
         expand=2,    # Block expansion factor
     ).to(device)
 
     # reference output for forwardding hidden_states
-    out_ref = mamba(hidden_states)
-    out_ref = pack(out_ref, cu_seqlens).unsqueeze(0)
+    out_ref_original = mamba_ref(hidden_states)
+    out_ref = pack(out_ref_original, cu_seqlens).unsqueeze(0)
 
     # output for forwardding packed_hidden_states with cu_seqlens
+    mamba = copy.deepcopy(mamba_ref)
     out = mamba(packed_hidden_states, cu_seqlens)
 
     # Testing the max/mean diff
-    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
-    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
+    print(f'Output max diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().max().item()}')
+    print(f'Output mean diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().mean().item()}')
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+    # bwd for mamba w/ cu_seqlens
+    g = torch.randn_like(out)
+    out.backward(g)
+    mamba_grad = {name: param.grad.clone() for name, param in mamba.named_parameters()}
+
+    # bwd for mamba wo/ cu_seqlens
+    g_ref = unpack(g, cu_seqlens)
+    out_ref_original.backward(g_ref)
+    mamba_ref_grad = {name: param.grad.clone() for name, param in mamba_ref.named_parameters()}
+
+    # check bwd pass
+    assert set(mamba_grad.keys()) == set(mamba_ref_grad.keys())
+    for name in mamba_ref_grad:
+        print(f'Output max diff for {name} in varlen_mamba bwd pass: {(mamba_grad[name] - mamba_ref_grad[name]).abs().max().item()}')
+        print(f'Output mean diff for {name} in varlen_mamba bwd pass: {(mamba_grad[name] - mamba_ref_grad[name]).abs().mean().item()}')
+        assert torch.allclose(mamba_grad[name], mamba_ref_grad[name], rtol=rtol, atol=atol)
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
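
Usage sketch (illustrative only, not part of the patch): packing two variable-length sequences into one batch element and calling the patched Mamba forward with cu_seqlens. The module sizes, example lengths, and the int32 dtype below are assumptions for demonstration; the constraints stated by the patch are batch == 1 and a 1-D cu_seqlens that starts at 0 and ends at the total packed length.

    import torch
    from mamba_ssm.modules.mamba_simple import Mamba

    device = "cuda"  # the fused varlen kernels run on CUDA
    d_model = 16

    mamba = Mamba(d_model=d_model, d_state=16, d_conv=4, expand=2).to(device)

    # Two sequences of lengths 3 and 5, packed along the length dimension (batch == 1).
    packed = torch.randn(1, 8, d_model, device=device)

    # Cumulative sequence lengths: start at 0, end at the total packed length (8).
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)

    out = mamba(packed, cu_seqlens=cu_seqlens)  # same shape as `packed`: (1, 8, d_model)

Each packed sequence is convolved and scanned independently, so the result should match running the sequences separately and re-packing the outputs, which is what tests/ops/test_mamba_cu_seqlens_equivalence.py verifies for both the forward and backward passes.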