From d28e1b0d7b4a665e52e385d570c81225d7d7b6c6 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Fri, 8 Mar 2024 15:38:51 +0800 Subject: [PATCH 01/27] add cu_seqlens support and ensure numerical equality --- csrc/selective_scan/selective_scan.cpp | 28 ++++-- csrc/selective_scan/selective_scan.h | 4 + .../selective_scan_bwd_kernel.cuh | 31 ++++++ .../selective_scan_fwd_kernel.cuh | 30 ++++++ mamba_ssm/modules/mamba_simple.py | 30 +++++- mamba_ssm/ops/selective_scan_interface.py | 98 +++++++++++++++---- 6 files changed, 193 insertions(+), 28 deletions(-) diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp index cde867cd..d545547f 100644 --- a/csrc/selective_scan/selective_scan.cpp +++ b/csrc/selective_scan/selective_scan.cpp @@ -79,7 +79,9 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, void* delta_bias_ptr, void* x_ptr, bool has_z, - bool delta_softplus) { + bool delta_softplus, + void* cu_seqlens_ptr, + const int cu_seqlens_size) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -109,6 +111,10 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.x_ptr = x_ptr; params.z_ptr = has_z ? z.data_ptr() : nullptr; params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + + params.cu_seqlens_ptr = cu_seqlens_ptr; + params.cu_seqlens_size = cu_seqlens_size; + // All stride are in elements, not bytes. params.A_d_stride = A.stride(0); params.A_dstate_stride = A.stride(1); @@ -173,7 +179,9 @@ void set_ssm_params_bwd(SSMParamsBwd ¶ms, void* ddelta_bias_ptr, bool has_z, bool delta_softplus, - bool recompute_out_z) { + bool recompute_out_z, + void* cu_seqlens_ptr, + const int cu_seqlens_size) { // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, u, delta, A, B, C, has_z ? out : dout, @@ -181,7 +189,7 @@ void set_ssm_params_bwd(SSMParamsBwd ¶ms, // If not recompute_out_z, pass dout instead of out_z. // This won't be used by the bwd kernel recompute_out_z ? out_z : dout, - D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); + D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus, cu_seqlens_ptr, cu_seqlens_size); if (!recompute_out_z) { params.out_z_ptr = nullptr; } // Set the pointers and strides. @@ -229,7 +237,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, const c10::optional &D_, const c10::optional &z_, const c10::optional &delta_bias_, - bool delta_softplus) { + bool delta_softplus, + const c10::optional &cu_seqlens_) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -319,7 +328,9 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, x.data_ptr(), has_z, - delta_softplus); + delta_softplus, + cu_seqlens_.has_value() ? cu_seqlens_.value().data_ptr() : nullptr, + cu_seqlens_.has_value() ? 
cu_seqlens_.value().size(0) : 0); // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing @@ -346,7 +357,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, const c10::optional &out_, c10::optional &dz_, bool delta_softplus, - bool recompute_out_z) { + bool recompute_out_z, + const c10::optional &cu_seqlens_) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -474,7 +486,9 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, dout, du, ddelta, dA, dB, dC, dz, D_.has_value() ? dD.data_ptr() : nullptr, delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, - has_z, delta_softplus, recompute_out_z); + has_z, delta_softplus, recompute_out_z, + cu_seqlens_.has_value() ? cu_seqlens_.value().data_ptr() : nullptr, + cu_seqlens_.has_value() ? cu_seqlens_.value().size(0) : 0); // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing diff --git a/csrc/selective_scan/selective_scan.h b/csrc/selective_scan/selective_scan.h index e2c7bcdb..62550e8c 100644 --- a/csrc/selective_scan/selective_scan.h +++ b/csrc/selective_scan/selective_scan.h @@ -33,6 +33,8 @@ struct SSMParamsBase { bool delta_softplus; + int cu_seqlens_size; + index_t A_d_stride; index_t A_dstate_stride; index_t B_batch_stride; @@ -66,6 +68,8 @@ struct SSMParamsBase { void *__restrict__ x_ptr; void *__restrict__ z_ptr; void *__restrict__ out_z_ptr; + + void *__restrict__ cu_seqlens_ptr; }; struct SSMParamsBwd: public SSMParamsBase { diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index 2ed10114..1d7b1996 100644 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -136,6 +136,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; float dD_val = 0; float ddelta_bias_val = 0; + long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride constexpr int kChunkSize = kNThreads * kNItems; u += (params.n_chunks - 1) * kChunkSize; @@ -245,7 +246,22 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { #pragma unroll for (int i = 0; i < kNItems; ++i) { const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + + // Reset A bar for cumulative sequences (Real) + int left = 1; + int right = params.cu_seqlens_size - 2; + while (left <= right) { + if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { + delta_a_exp = 0.f; + } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { + left = ((left + right) >> 1) + 1; + } else { + right = ((left + right) >> 1) - 1; + } + } + thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + if (i == 0) { smem_delta_a[threadIdx.x == 0 ? 
state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; } else { @@ -332,6 +348,21 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { for (int i = 0; i < kNItems; ++i) { // Pytorch's implementation of complex exp (which calls thrust) is very slow complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); + + // Reset A bar for cumulative sequences (Complex) + int left = 1; + int right = params.cu_seqlens_size - 2; + while (left <= right) { + if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { + delta_a_exp.real_ = 0.f; + delta_a_exp.imag_ = 0.f; + } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { + left = ((left + right) >> 1) + 1; + } else { + right = ((left + right) >> 1) - 1; + } + } + weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); if (i == 0) { diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 440a2091..131ca720 100644 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -107,6 +107,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { @@ -215,6 +216,20 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (!kIsComplex) { thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + + // Reset A bar for cumulative sequences (Real) + int left = 1; + int right = params.cu_seqlens_size - 2; + while (left <= right) { + if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { + thread_data[i].x = 0.f; + } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { + left = ((left + right) >> 1) + 1; + } else { + right = ((left + right) >> 1) - 1; + } + } + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { thread_data[i] = make_float2(1.f, 0.f); @@ -225,6 +240,21 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); weight_t B_delta_u_val = !kIsVariableB ? 
delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + + // Reset A bar for cumulative sequences (Complex) + int left = 1; + int right = params.cu_seqlens_size - 2; + while (left <= right) { + if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { + thread_data[i].x = 0.f; + thread_data[i].y = 0.f; + } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { + left = ((left + right) >> 1) + 1; + } else { + right = ((left + right) >> 1) - 1; + } + } + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 98d97a57..a877558b 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -10,7 +10,7 @@ from einops import rearrange, repeat -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, selective_scan_ref try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update @@ -116,7 +116,7 @@ def __init__( self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - def forward(self, hidden_states, inference_params=None): + def forward(self, hidden_states, cu_seqlens=None, inference_params=None): """ hidden_states: (B, L, D) Returns: same shape as hidden_states @@ -157,9 +157,22 @@ def forward(self, hidden_states, inference_params=None): self.D.float(), delta_bias=self.dt_proj.bias.float(), delta_softplus=True, + cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None, ) else: x, z = xz.chunk(2, dim=1) + + # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences + if cu_seqlens is not None: + padded_x = x + count = 0 + for idx in cu_seqlens[0][1:-1].tolist(): + padded_idx = idx + count*(self.d_conv - 1) + padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) + count = count + 1 + x = padded_x + assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[0][1:-1]) + z.shape[2] + # Compute short convolution if conv_state is not None: # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv @@ -175,6 +188,17 @@ def forward(self, hidden_states, inference_params=None): bias=self.conv1d.bias, activation=self.activation, ) + + # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences + if cu_seqlens is not None: + mask = [] + for seq_len in (cu_seqlens[0][1:] - cu_seqlens[0][:-1]).tolist(): + mask.extend([True] * seq_len) + mask.extend([False] * (self.d_conv - 1)) + mask = mask[:-(self.d_conv - 1)] + assert x.shape[2] == len(mask) + x = x[:, :, mask] + assert x.shape[2] == z.shape[2] # We're careful here about the layout, to avoid extra transposes. 
# We want dt to have d as the slowest moving dimension @@ -185,6 +209,7 @@ def forward(self, hidden_states, inference_params=None): dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + assert self.activation in ["silu", "swish"] y = selective_scan_fn( x, @@ -197,6 +222,7 @@ def forward(self, hidden_states, inference_params=None): delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=ssm_state is not None, + cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None, ) if ssm_state is not None: y, last_state = y diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index b8f14dd0..7e8f6969 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -15,7 +15,7 @@ class SelectiveScanFn(torch.autograd.Function): @staticmethod def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, cu_seqlens=None): if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: @@ -34,26 +34,26 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seqlens) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if not ctx.has_z: - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x, cu_seqlens) return out if not return_last_state else (out, last_state) else: - ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out, cu_seqlens) out_z = rest[0] return out_z if not return_last_state else (out_z, last_state) @staticmethod def backward(ctx, dout, *args): if not ctx.has_z: - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + u, delta, A, B, C, D, delta_bias, x, cu_seqlens = ctx.saved_tensors z = None out = None else: - u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + u, delta, A, B, C, D, z, delta_bias, x, out, cu_seqlens = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the @@ -61,7 +61,8 @@ def backward(ctx, dout, *args): # Here we just pass in None and dz will be allocated in the C++ code. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False # option to recompute out_z, not used here + False, # option to recompute out_z, not used here + cu_seqlens ) dz = rest[0] if ctx.has_z else None dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB @@ -71,20 +72,21 @@ def backward(ctx, dout, *args): dz, ddelta_bias if delta_bias is not None else None, None, + None, None) def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, cu_seqlens=None): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). 
Note that the gradient of the last state is not considered in the backward pass. """ - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, cu_seqlens) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, cu_seqlens=None): """ u: r(B D L) delta: r(B D L) @@ -131,7 +133,10 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) last_state = None for i in range(u.shape[2]): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if cu_seqlens is not None and i in cu_seqlens[1:-1].tolist(): + x = deltaB_u[:, :, i] + else: + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: y = torch.einsum('bdn,dn->bd', x, C) else: @@ -159,7 +164,7 @@ class MambaInnerFn(torch.autograd.Function): def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ @@ -177,13 +182,39 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) + + # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences + d_conv = 4 + if cu_seqlens is not None: + padded_x = x + count = 0 + for idx in cu_seqlens[1:-1].tolist(): + padded_idx = idx + count*(d_conv - 1) + padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) + count = count + 1 + x = padded_x + assert x.shape[2] == (d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2] + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) + + # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences + if cu_seqlens is not None: + mask = [] + for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): + mask.extend([True] * seq_len) + mask.extend([False] * (d_conv - 1)) + mask = mask[:-(d_conv - 1)] + assert conv1d_out.shape[2] == len(mask) + conv1d_out = conv1d_out[:, :, mask] + assert conv1d_out.shape[2] == z.shape[2] + # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + ctx.is_variable_B = B is None ctx.is_variable_C = C is None ctx.B_proj_bias_is_None = B_proj_bias is None @@ -215,7 +246,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh if D is not None: D = D.contiguous() out, scan_intermediates, out_z = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seqlens ) ctx.delta_softplus = delta_softplus ctx.out_proj_bias_is_None = out_proj_bias is None @@ -224,7 +255,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out) + A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @@ -232,17 +263,45 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh def backward(ctx, dout): # dout: (batch, seqlen, dim) (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, - conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens) = ctx.saved_tensors L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) if dout.stride(-1) != 1: dout = dout.contiguous() + d_conv = 4 + + x_bak = x if ctx.checkpoint_lvl == 1: + # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences + if cu_seqlens is not None: + padded_x = x + count = 0 + for idx in cu_seqlens[1:-1].tolist(): + padded_idx = idx + count*(d_conv - 1) + padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) + count = count + 1 + x = padded_x + assert x.shape[2] == (d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2] + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) + + # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences + if cu_seqlens is not None: + mask = [] + for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): + mask.extend([True] * seq_len) + mask.extend([False] * (d_conv - 1)) + mask = mask[:-(d_conv - 1)] + assert conv1d_out.shape[2] == len(mask) + conv1d_out = conv1d_out[:, :, mask] + assert conv1d_out.shape[2] == z.shape[2] + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + x = x_bak + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). 
dxz = torch.empty_like(xz) # (batch, dim, seqlen) @@ -252,7 +311,8 @@ def backward(ctx, dout): dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, ctx.delta_softplus, - True # option to recompute out_z + True, # option to recompute out_z + cu_seqlens ) dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None @@ -294,18 +354,18 @@ def backward(ctx, dout): dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None) + dB_proj_bias, dC_proj_bias, None, None) def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True + C_proj_bias=None, delta_softplus=True, cu_seqlens=None ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens) def mamba_inner_ref( From a78a9ebb539051fe92bb328f6af4f27ac0552a1c Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Thu, 14 Mar 2024 16:20:19 +0800 Subject: [PATCH 02/27] add notes for variable length sequences --- mamba_ssm/modules/mamba_simple.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index a877558b..99666f92 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -119,6 +119,7 @@ def __init__( def forward(self, hidden_states, cu_seqlens=None, inference_params=None): """ hidden_states: (B, L, D) + cu_seqlens: one-dimensional tensor like flash-attn varlen API, only used for variable-length sequences and packing variable-length sequences into one, a.k.a., batch_size B=1 Returns: same shape as hidden_states """ batch, seqlen, dim = hidden_states.shape @@ -157,7 +158,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): self.D.float(), delta_bias=self.dt_proj.bias.float(), delta_softplus=True, - cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None, + cu_seqlens=cu_seqlens, ) else: x, z = xz.chunk(2, dim=1) @@ -166,12 +167,12 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): if cu_seqlens is not None: padded_x = x count = 0 - for idx in cu_seqlens[0][1:-1].tolist(): + for idx in cu_seqlens[1:-1].tolist(): padded_idx = idx + count*(self.d_conv - 1) padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) count = count + 1 x = padded_x - assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[0][1:-1]) + z.shape[2] + # assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2] # Compute short convolution if conv_state is not None: @@ -192,13 +193,13 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences if cu_seqlens is not None: mask = [] - for seq_len in (cu_seqlens[0][1:] - cu_seqlens[0][:-1]).tolist(): + for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): 
mask.extend([True] * seq_len) mask.extend([False] * (self.d_conv - 1)) mask = mask[:-(self.d_conv - 1)] - assert x.shape[2] == len(mask) + # assert x.shape[2] == len(mask) x = x[:, :, mask] - assert x.shape[2] == z.shape[2] + # assert x.shape[2] == z.shape[2] # We're careful here about the layout, to avoid extra transposes. # We want dt to have d as the slowest moving dimension @@ -222,7 +223,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=ssm_state is not None, - cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None, + cu_seqlens=cu_seqlens, ) if ssm_state is not None: y, last_state = y From e223353c495510ecc7b7488678840e268e6b3cb2 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Fri, 15 Mar 2024 17:32:13 +0800 Subject: [PATCH 03/27] fix typos --- csrc/selective_scan/selective_scan_bwd_kernel.cuh | 2 +- csrc/selective_scan/selective_scan_fwd_kernel.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index 1d7b1996..090a59b1 100644 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -136,7 +136,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; float dD_val = 0; float ddelta_bias_val = 0; - long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride + long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride; constexpr int kChunkSize = kNThreads * kNItems; u += (params.n_chunks - 1) * kChunkSize; diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 131ca720..8ecf126d 100644 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -107,7 +107,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; - long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride + long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride; float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { From 595545062f402941ea3b43092ded3c85042c4d08 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 18 Mar 2024 15:43:39 +0800 Subject: [PATCH 04/27] fix typos --- csrc/selective_scan/selective_scan_bwd_kernel.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index 090a59b1..b204ab3f 100644 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -245,7 +245,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { if constexpr (!kIsComplex) { #pragma unroll for (int i = 0; i < kNItems; ++i) { - const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + float delta_a_exp = exp2f(delta_vals[i] * A_scaled); // Reset A bar for cumulative sequences (Real) int 
left = 1; From c2d5b88d2f57c48e8f588c3ab2994c0da695e6cc Mon Sep 17 00:00:00 2001 From: Dmovic <944388576@qq.com> Date: Mon, 18 Mar 2024 08:24:59 +0000 Subject: [PATCH 05/27] fix typos --- csrc/selective_scan/selective_scan_bwd_kernel.cuh | 2 ++ csrc/selective_scan/selective_scan_fwd_kernel.cuh | 2 ++ 2 files changed, 4 insertions(+) diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index b204ab3f..0b159871 100644 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -253,6 +253,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { while (left <= right) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { delta_a_exp = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { @@ -356,6 +357,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { delta_a_exp.real_ = 0.f; delta_a_exp.imag_ = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 8ecf126d..42a95b9d 100644 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -223,6 +223,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { while (left <= right) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { thread_data[i].x = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { @@ -248,6 +249,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { thread_data[i].x = 0.f; thread_data[i].y = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { From db0dd09aeb13cd592be1c24068ae1704247f3d19 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 18 Mar 2024 16:36:35 +0800 Subject: [PATCH 06/27] fix typos --- csrc/selective_scan/selective_scan_bwd_kernel.cuh | 2 ++ csrc/selective_scan/selective_scan_fwd_kernel.cuh | 2 ++ 2 files changed, 4 insertions(+) diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index b204ab3f..0b159871 100644 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -253,6 +253,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { while (left <= right) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { delta_a_exp = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { @@ -356,6 +357,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { delta_a_exp.real_ = 0.f; delta_a_exp.imag_ = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 
1; } else { diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 8ecf126d..42a95b9d 100644 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -223,6 +223,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { while (left <= right) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { thread_data[i].x = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { @@ -248,6 +249,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { thread_data[i].x = 0.f; thread_data[i].y = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { From e7774aaa8df27a36af75c8a171d77962502e61f8 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 18 Mar 2024 23:29:55 +0800 Subject: [PATCH 07/27] refine cu_seqlens implementation --- mamba_ssm/modules/mamba_simple.py | 4 +-- mamba_ssm/ops/selective_scan_interface.py | 43 +++++++++++++++-------- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 303f8d33..c56e587c 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -159,6 +159,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): delta_bias=self.dt_proj.bias.float(), delta_softplus=True, cu_seqlens=cu_seqlens, + d_conv=self.d_conv, ) else: x, z = xz.chunk(2, dim=1) @@ -172,7 +173,6 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) count = count + 1 x = padded_x - # assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2] # Compute short convolution if conv_state is not None: @@ -197,9 +197,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): mask.extend([True] * seq_len) mask.extend([False] * (self.d_conv - 1)) mask = mask[:-(self.d_conv - 1)] - # assert x.shape[2] == len(mask) x = x[:, :, mask] - # assert x.shape[2] == z.shape[2] # We're careful here about the layout, to avoid extra transposes. 
# We want dt to have d as the slowest moving dimension diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index 874a0917..d8516d44 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -169,7 +169,7 @@ class MambaInnerFn(torch.autograd.Function): def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None, checkpoint_lvl=1): + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ @@ -190,7 +190,6 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh x, z = xz.chunk(2, dim=1) # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences - d_conv = 4 if cu_seqlens is not None: padded_x = x count = 0 @@ -199,7 +198,6 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) count = count + 1 x = padded_x - assert x.shape[2] == (d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2] conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) @@ -211,9 +209,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh mask.extend([True] * seq_len) mask.extend([False] * (d_conv - 1)) mask = mask[:-(d_conv - 1)] - assert conv1d_out.shape[2] == len(mask) conv1d_out = conv1d_out[:, :, mask] - assert conv1d_out.shape[2] == z.shape[2] # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension @@ -261,7 +257,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens) + A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @@ -270,14 +266,13 @@ def backward(ctx, dout): # dout: (batch, seqlen, dim) assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 
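        # NOTE: when checkpoint_lvl == 1, conv1d_out and delta were freed at the end of
        # forward(); they are recomputed below with the same cu_seqlens padding/masking
        # steps used in forward(), so the recomputed activations line up with the saved
        # scan_intermediates.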
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, - conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens) = ctx.saved_tensors + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) = ctx.saved_tensors L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) if dout.stride(-1) != 1: dout = dout.contiguous() - d_conv = 4 x_bak = x if ctx.checkpoint_lvl == 1: @@ -290,7 +285,6 @@ def backward(ctx, dout): padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) count = count + 1 x = padded_x - assert x.shape[2] == (d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2] conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) @@ -301,9 +295,7 @@ def backward(ctx, dout): mask.extend([True] * seq_len) mask.extend([False] * (d_conv - 1)) mask = mask[:-(d_conv - 1)] - assert conv1d_out.shape[2] == len(mask) conv1d_out = conv1d_out[:, :, mask] - assert conv1d_out.shape[2] == z.shape[2] delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) @@ -361,32 +353,53 @@ def backward(ctx, dout): dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None, None) + dB_proj_bias, dC_proj_bias, None, None, None) def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens) + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens, d_conv) def mamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None ): assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." 
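    # cu_seqlens follows the flash-attn varlen convention: a 1-D tensor of boundary
    # offsets for sequences packed along L with batch size 1, and d_conv is the
    # causal-conv1d kernel width used to pad (Step1) and then mask out (Step2)
    # d_conv - 1 zeros at every internal boundary, so the convolution cannot leak
    # information across packed sequences.
    # Illustrative values (an example, not taken from this patch): packing three
    # sequences of lengths 3, 5 and 4 into one row of length 12 would use
    # cu_seqlens = torch.tensor([0, 3, 8, 12], dtype=torch.int64, device=xz.device).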
L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) + + # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences + if cu_seqlens is not None: + padded_x = x + count = 0 + for idx in cu_seqlens[1:-1].tolist(): + padded_idx = idx + count*(d_conv - 1) + padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) + count = count + 1 + x = padded_x + x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu") + + # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences + if cu_seqlens is not None: + mask = [] + for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): + mask.extend([True] * seq_len) + mask.extend([False] * (d_conv - 1)) + mask = mask[:-(d_conv - 1)] + x = x[:, :, mask] + # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. From f357c44af53c47a889bbd40dfd5b9e86e3882021 Mon Sep 17 00:00:00 2001 From: Dmovic <944388576@qq.com> Date: Tue, 19 Mar 2024 03:43:40 +0000 Subject: [PATCH 08/27] add unit test for variable length --- tests/ops/test_selective_scan_var_len.py | 156 +++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 tests/ops/test_selective_scan_var_len.py diff --git a/tests/ops/test_selective_scan_var_len.py b/tests/ops/test_selective_scan_var_len.py new file mode 100644 index 00000000..cdf3d41f --- /dev/null +++ b/tests/ops/test_selective_scan_var_len.py @@ -0,0 +1,156 @@ +# Copyright (C) 2023, Tri Dao. 
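# Variable-length (packed) selective-scan test: seq_num sub-sequences are packed
# back-to-back into a single row (batch_size = 1, total length = seqlen) and the
# fused CUDA path (selective_scan_fn with cu_seqlens) is compared against the
# pure-PyTorch reference (selective_scan_ref with cu_seqlens), which restarts the
# SSM state at every interior boundary index; outputs, final states and gradients
# are all checked with allclose.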
+ +import math + +import torch +import torch.nn.functional as F +import pytest + +from einops import rearrange + +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref +from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref + + +@pytest.mark.parametrize('wtype', [torch.float32]) +@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("return_last_state", [True]) +@pytest.mark.parametrize('has_delta_bias', [True]) +@pytest.mark.parametrize('delta_softplus', [True]) +@pytest.mark.parametrize('has_z', [True]) +@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("varBC_groups", [1, 2]) +@pytest.mark.parametrize("is_variable_C", [True]) +@pytest.mark.parametrize("is_variable_B", [True]) +@pytest.mark.parametrize("seq_num", [1, 2, 3, 4, 5, 6, 7]) +def test_selective_scan_variable_length(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, itype, wtype, seq_num): + is_variable_B = True + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + batch_size = 1 + dim = 4 + dstate = 1 + is_complex = wtype == torch.complex64 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + if not is_variable_B: + B_shape = (dim, dstate) + elif varBC_groups == 1: + B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, + requires_grad=True) + if not is_variable_C: + C_shape = (dim, dstate) + elif varBC_groups == 1: + C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, + requires_grad=True) + if has_D: + D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + D = None + if has_z: + z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + else: + z = None + if has_delta_bias: + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + else: + delta_bias = None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() + + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() + C_ref = C.detach().clone().requires_grad_() + D_ref = D.detach().clone().requires_grad_() if D is not None else None + z_ref = z.detach().clone().requires_grad_() if z is not None else None + u_ref = u.detach().clone().requires_grad_() + delta_ref = delta.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + start_indexes = torch.arange(0, 
seqlen, seqlen / seq_num, dtype=torch.int64).cuda() + + out, *rest = selective_scan_fn( + u, delta, A, B, C, D, z=z, + delta_bias=delta_bias, delta_softplus=delta_softplus, + return_last_state=return_last_state, cu_seqlens=start_indexes + ) + if return_last_state: + state = rest[0] + out_ref, *rest = selective_scan_ref( + u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, + delta_bias=delta_bias_ref, delta_softplus=delta_softplus, + return_last_state=return_last_state, cu_seqlens=start_indexes + ) + if return_last_state: + state_ref = rest[0] + # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + # dt_u = delta * u + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + if return_last_state: + print(f'State max diff: {(state - state_ref).abs().max().item()}') + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + + g = torch.randn_like(out) + out_ref.backward(g) + out.backward(g) + + print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}') + print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}') + print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') + print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') + print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') + if has_D: + print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') + if has_z: + print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}') + if has_delta_bias: + print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') + + assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) + assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) + assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) + assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, + atol=atolw if not is_variable_B else atol) + assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, + atol=atolw if not is_variable_C else atol) + if has_D: + assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) + if has_z: + assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) + if has_delta_bias: + assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) + +if __name__ == "__main__" : + wtype = torch.float32 + itype = torch.float32 + seqlen = 8 + return_last_state = True + has_delta_bias = True + delta_softplus = False + has_z = True + has_D = True + varBC_groups = 1 + is_variable_C = True + is_variable_B = True + seq_num = 4 + test_selective_scan_variable_length(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, itype, wtype, seq_num) From 6b98161c17c4125a12ed959f32685d622fc79adf Mon Sep 17 00:00:00 2001 From: Dmovic <944388576@qq.com> Date: Tue, 19 Mar 2024 04:43:48 +0000 Subject: [PATCH 09/27] update unit test --- tests/ops/test_selective_scan_var_len.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/ops/test_selective_scan_var_len.py b/tests/ops/test_selective_scan_var_len.py index cdf3d41f..199d4610 100644 --- a/tests/ops/test_selective_scan_var_len.py +++ b/tests/ops/test_selective_scan_var_len.py @@ -9,7 +9,6 @@ from einops import rearrange from mamba_ssm.ops.selective_scan_interface import 
selective_scan_fn, selective_scan_ref -from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref @pytest.mark.parametrize('wtype', [torch.float32]) @@ -138,19 +137,3 @@ def test_selective_scan_variable_length(is_variable_B, is_variable_C, varBC_grou assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) if has_delta_bias: assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) - -if __name__ == "__main__" : - wtype = torch.float32 - itype = torch.float32 - seqlen = 8 - return_last_state = True - has_delta_bias = True - delta_softplus = False - has_z = True - has_D = True - varBC_groups = 1 - is_variable_C = True - is_variable_B = True - seq_num = 4 - test_selective_scan_variable_length(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, - delta_softplus, return_last_state, seqlen, itype, wtype, seq_num) From e4af927d7accfddf9ca0ec0d96b8b812c49324ac Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Tue, 19 Mar 2024 17:49:49 +0800 Subject: [PATCH 10/27] fix typos --- mamba_ssm/ops/selective_scan_interface.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index d8516d44..59bf9a3c 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -257,7 +257,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) + A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, torch.tensor([d_conv])) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @@ -267,6 +267,7 @@ def backward(ctx, dout): assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 
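        # d_conv was stashed via save_for_backward as a one-element tensor above
        # (save_for_backward only accepts tensors), so it is turned back into a
        # plain Python int with .item() right after unpacking.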
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) = ctx.saved_tensors + d_conv = d_conv.item() L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) From 4221d48e3766a56bff01ecd52f607b6a18ed7e3d Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 25 Mar 2024 11:22:18 +0800 Subject: [PATCH 11/27] update selective scan --- mamba_ssm/ops/selective_scan_interface.py | 122 +++++----------------- 1 file changed, 26 insertions(+), 96 deletions(-) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index 59bf9a3c..c3596bfe 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -20,7 +20,7 @@ class SelectiveScanFn(torch.autograd.Function): @staticmethod def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False, cu_seqlens=None): + return_last_state=False): if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: @@ -39,26 +39,26 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seqlens) + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if not ctx.has_z: - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x, cu_seqlens) + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out if not return_last_state else (out, last_state) else: - ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out, cu_seqlens) + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) out_z = rest[0] return out_z if not return_last_state else (out_z, last_state) @staticmethod def backward(ctx, dout, *args): if not ctx.has_z: - u, delta, A, B, C, D, delta_bias, x, cu_seqlens = ctx.saved_tensors + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors z = None out = None else: - u, delta, A, B, C, D, z, delta_bias, x, out, cu_seqlens = ctx.saved_tensors + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the @@ -66,8 +66,7 @@ def backward(ctx, dout, *args): # Here we just pass in None and dz will be allocated in the C++ code. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False, # option to recompute out_z, not used here - cu_seqlens + False # option to recompute out_z, not used here ) dz = rest[0] if ctx.has_z else None dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB @@ -77,21 +76,20 @@ def backward(ctx, dout, *args): dz, ddelta_bias if delta_bias is not None else None, None, - None, None) def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False, cu_seqlens=None): + return_last_state=False): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). 
Note that the gradient of the last state is not considered in the backward pass. """ - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, cu_seqlens) + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False, cu_seqlens=None): + return_last_state=False): """ u: r(B D L) delta: r(B D L) @@ -138,10 +136,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) last_state = None for i in range(u.shape[2]): - if cu_seqlens is not None and i in cu_seqlens[1:-1].tolist(): - x = deltaB_u[:, :, i] - else: - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: y = torch.einsum('bdn,dn->bd', x, C) else: @@ -169,7 +164,7 @@ class MambaInnerFn(torch.autograd.Function): def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None, checkpoint_lvl=1): + C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ @@ -188,35 +183,15 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) - - # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences - if cu_seqlens is not None: - padded_x = x - count = 0 - for idx in cu_seqlens[1:-1].tolist(): - padded_idx = idx + count*(d_conv - 1) - padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) - count = count + 1 - x = padded_x - conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) - - # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences - if cu_seqlens is not None: - mask = [] - for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): - mask.extend([True] * seq_len) - mask.extend([False] * (d_conv - 1)) - mask = mask[:-(d_conv - 1)] - conv1d_out = conv1d_out[:, :, mask] - + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x, conv1d_weight, conv1d_bias, None, None, None, True + ) # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) - ctx.is_variable_B = B is None ctx.is_variable_C = C is None ctx.B_proj_bias_is_None = B_proj_bias is None @@ -248,7 +223,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh if D is not None: D = D.contiguous() out, scan_intermediates, out_z = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seqlens + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus ) ctx.delta_softplus = delta_softplus ctx.out_proj_bias_is_None = out_proj_bias is None @@ -257,7 +232,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, torch.tensor([d_conv])) + A, B, C, D, delta_bias, scan_intermediates, out) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @@ -266,42 +241,19 @@ def backward(ctx, dout): # dout: (batch, seqlen, dim) assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, - conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) = ctx.saved_tensors - d_conv = d_conv.item() + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) if dout.stride(-1) != 1: dout = dout.contiguous() - - x_bak = x if ctx.checkpoint_lvl == 1: - # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences - if cu_seqlens is not None: - padded_x = x - count = 0 - for idx in cu_seqlens[1:-1].tolist(): - padded_idx = idx + count*(d_conv - 1) - padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) - count = count + 1 - x = padded_x - - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) - - # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences - if cu_seqlens is not None: - mask = [] - for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): - mask.extend([True] * seq_len) - mask.extend([False] * (d_conv - 1)) - mask = mask[:-(d_conv - 1)] - conv1d_out = conv1d_out[:, :, mask] - + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x, conv1d_weight, conv1d_bias, None, None, None, True + ) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) - x = x_bak - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). 
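The checkpoint_lvl branch above is the usual recompute-in-backward trade: with checkpoint_lvl >= 1, conv1d_out and delta are dropped from ctx and rebuilt from the saved inputs during the backward pass, saving activation memory at the cost of an extra conv1d and matmul. A stripped-down sketch of the same pattern on a toy function (ExpScale and its shapes are illustrative, not part of the Mamba kernels):

    import torch

    class ExpScale(torch.autograd.Function):
        # toy stand-in: y = exp(x) * w, where exp(x) plays the role of conv1d_out/delta

        @staticmethod
        def forward(ctx, x, w, checkpoint_lvl):
            e = torch.exp(x)
            ctx.checkpoint_lvl = checkpoint_lvl
            if checkpoint_lvl >= 1:
                ctx.save_for_backward(x, w)       # save only the inputs
            else:
                ctx.save_for_backward(x, w, e)    # save the intermediate as well
            return e * w

        @staticmethod
        def backward(ctx, dout):
            if ctx.checkpoint_lvl >= 1:
                x, w = ctx.saved_tensors
                e = torch.exp(x)                  # recomputed here, like conv1d_out/delta above
            else:
                x, w, e = ctx.saved_tensors
            return dout * e * w, (dout * e).sum(), None

    x = torch.randn(4, requires_grad=True)
    w = torch.tensor(2.0, requires_grad=True)
    ExpScale.apply(x, w, 1).sum().backward()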
dxz = torch.empty_like(xz) # (batch, dim, seqlen) @@ -311,8 +263,7 @@ def backward(ctx, dout): dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, ctx.delta_softplus, - True, # option to recompute out_z - cu_seqlens + True # option to recompute out_z ) dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None @@ -354,53 +305,32 @@ def backward(ctx, dout): dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None, None, None) + dB_proj_bias, dC_proj_bias, None) def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None + C_proj_bias=None, delta_softplus=True ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens, d_conv) + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) def mamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None + C_proj_bias=None, delta_softplus=True ): assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) - - # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences - if cu_seqlens is not None: - padded_x = x - count = 0 - for idx in cu_seqlens[1:-1].tolist(): - padded_idx = idx + count*(d_conv - 1) - padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) - count = count + 1 - x = padded_x - x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu") - - # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences - if cu_seqlens is not None: - mask = [] - for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): - mask.extend([True] * seq_len) - mask.extend([False] * (d_conv - 1)) - mask = mask[:-(d_conv - 1)] - x = x[:, :, mask] - # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
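A note on the trailing Nones in the return of MambaInnerFn.backward: torch.autograd.Function expects backward to return exactly one gradient per argument of forward, with None for non-differentiable arguments, which is why the tuple gains or loses None entries whenever cu_seqlens or d_conv parameters are added to or removed from the signature in this series. A toy illustration of that rule (Scale is an illustrative example, not repository code):

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, scale, flag):
            ctx.scale = scale
            return x * scale

        @staticmethod
        def backward(ctx, dout):
            # one entry per forward input: x, the python float scale, the unused flag
            return dout * ctx.scale, None, None

    x = torch.randn(3, requires_grad=True)
    Scale.apply(x, 2.0, "unused").sum().backward()
    assert torch.allclose(x.grad, torch.full((3,), 2.0))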
From 934c0e679d791d0e8cec898b50cc0ee9d24af242 Mon Sep 17 00:00:00 2001 From: wangzerui Date: Mon, 25 Mar 2024 11:37:55 +0800 Subject: [PATCH 12/27] Add logic for variable-length sequences --- mamba_ssm/ops/selective_scan_interface.py | 124 +++++++++++++++++----- 1 file changed, 97 insertions(+), 27 deletions(-) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index c3596bfe..af28c417 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -20,7 +20,7 @@ class SelectiveScanFn(torch.autograd.Function): @staticmethod def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, cu_seqlens=None): if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: @@ -39,26 +39,26 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seqlens) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if not ctx.has_z: - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x, cu_seqlens) return out if not return_last_state else (out, last_state) else: - ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out, cu_seqlens) out_z = rest[0] return out_z if not return_last_state else (out_z, last_state) @staticmethod def backward(ctx, dout, *args): if not ctx.has_z: - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + u, delta, A, B, C, D, delta_bias, x, cu_seqlens = ctx.saved_tensors z = None out = None else: - u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + u, delta, A, B, C, D, z, delta_bias, x, out, cu_seqlens = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the @@ -66,7 +66,8 @@ def backward(ctx, dout, *args): # Here we just pass in None and dz will be allocated in the C++ code. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False # option to recompute out_z, not used here + False, # option to recompute out_z, not used here + cu_seqlens ) dz = rest[0] if ctx.has_z else None dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB @@ -76,20 +77,21 @@ def backward(ctx, dout, *args): dz, ddelta_bias if delta_bias is not None else None, None, + None, None) def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, cu_seqlens=None): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the backward pass. 
""" - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, cu_seqlens) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, cu_seqlens=None): """ u: r(B D L) delta: r(B D L) @@ -136,7 +138,10 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) last_state = None for i in range(u.shape[2]): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if cu_seqlens is not None and i in cu_seqlens[1:-1].tolist(): + x = deltaB_u[:, :, i] + else: + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: y = torch.einsum('bdn,dn->bd', x, C) else: @@ -164,7 +169,7 @@ class MambaInnerFn(torch.autograd.Function): def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ @@ -183,15 +188,35 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) + + # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences + if cu_seqlens is not None: + padded_x = x + count = 0 + for idx in cu_seqlens[1:-1].tolist(): + padded_idx = idx + count*(d_conv - 1) + padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) + count = count + 1 + x = padded_x + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( - x, conv1d_weight, conv1d_bias, None, None, None, True - ) + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) + + # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences + if cu_seqlens is not None: + mask = [] + for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): + mask.extend([True] * seq_len) + mask.extend([False] * (d_conv - 1)) + mask = mask[:-(d_conv - 1)] + conv1d_out = conv1d_out[:, :, mask] + # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + ctx.is_variable_B = B is None ctx.is_variable_C = C is None ctx.B_proj_bias_is_None = B_proj_bias is None @@ -223,7 +248,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh if D is not None: D = D.contiguous() out, scan_intermediates, out_z = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seqlens ) ctx.delta_softplus = delta_softplus ctx.out_proj_bias_is_None = out_proj_bias is None @@ -232,7 +257,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out) + A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, torch.tensor([d_conv])) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @@ -241,19 +266,42 @@ def backward(ctx, dout): # dout: (batch, seqlen, dim) assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, - conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) = ctx.saved_tensors + d_conv = d_conv.item() L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) if dout.stride(-1) != 1: dout = dout.contiguous() + + x_bak = x if ctx.checkpoint_lvl == 1: - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( - x, conv1d_weight, conv1d_bias, None, None, None, True - ) + # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences + if cu_seqlens is not None: + padded_x = x + count = 0 + for idx in cu_seqlens[1:-1].tolist(): + padded_idx = idx + count*(d_conv - 1) + padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) + count = count + 1 + x = padded_x + + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) + + # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences + if cu_seqlens is not None: + mask = [] + for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): + mask.extend([True] * seq_len) + mask.extend([False] * (d_conv - 1)) + mask = mask[:-(d_conv - 1)] + conv1d_out = conv1d_out[:, :, mask] + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + x = x_bak + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). 
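The two optional steps used in the forward pass and repeated in the recomputation above implement one idea: insert d_conv - 1 zeros after every interior boundary so the causal conv1d warms up from zeros for each sequence, run the convolution, then drop the filler positions with a boolean mask. A small helper that builds the padded input and the keep-mask in one place (pad_and_mask and its shapes are illustrative, not a drop-in replacement for the code above):

    import torch

    def pad_and_mask(x, cu_seqlens, d_conv):
        # x: (1, dim, L) packed input; returns the padded input and the boolean keep-mask
        bounds = list(zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()))
        pieces, keep = [], []
        for i, (s, e) in enumerate(bounds):
            pieces.append(x[:, :, s:e])
            keep.append(torch.ones(e - s, dtype=torch.bool, device=x.device))
            if i < len(bounds) - 1:                      # no filler after the last sequence
                pieces.append(x.new_zeros(x.shape[0], x.shape[1], d_conv - 1))
                keep.append(torch.zeros(d_conv - 1, dtype=torch.bool, device=x.device))
        return torch.cat(pieces, dim=2), torch.cat(keep, dim=0)

    x = torch.randn(1, 4, 10)
    cu_seqlens = torch.tensor([0, 3, 8, 10])
    padded, mask = pad_and_mask(x, cu_seqlens, d_conv=4)
    assert padded.shape[-1] == 10 + 2 * 3 and int(mask.sum()) == 10
    # indexing conv1d(padded) with [..., mask] restores length 10, as in Step2 above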
dxz = torch.empty_like(xz) # (batch, dim, seqlen) @@ -263,7 +311,8 @@ def backward(ctx, dout): dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, ctx.delta_softplus, - True # option to recompute out_z + True, # option to recompute out_z + cu_seqlens ) dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None @@ -305,32 +354,53 @@ def backward(ctx, dout): dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None) + dB_proj_bias, dC_proj_bias, None, None, None) def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens, d_conv) def mamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None ): assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) + + # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences + if cu_seqlens is not None: + padded_x = x + count = 0 + for idx in cu_seqlens[1:-1].tolist(): + padded_idx = idx + count*(d_conv - 1) + padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) + count = count + 1 + x = padded_x + x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu") + + # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences + if cu_seqlens is not None: + mask = [] + for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): + mask.extend([True] * seq_len) + mask.extend([False] * (d_conv - 1)) + mask = mask[:-(d_conv - 1)] + x = x[:, :, mask] + # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
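With this patch in place, the packed variant is selected simply by passing the boundary tensor along with the usual packed (batch = 1) tensors. A usage sketch with illustrative shapes, run against the pure-PyTorch reference so it does not need the CUDA extension (the CUDA-backed selective_scan_fn takes the same arguments on GPU tensors):

    import torch
    from mamba_ssm.ops.selective_scan_interface import selective_scan_ref

    dim, dstate, L = 4, 8, 10                      # illustrative sizes
    cu_seqlens = torch.tensor([0, 3, 8, 10])       # three sequences packed into one row

    u = torch.randn(1, dim, L)
    delta = 0.5 * torch.rand(1, dim, L)
    A = -0.5 * torch.rand(dim, dstate)
    B = torch.randn(1, 1, dstate, L)               # variable B, one group
    C = torch.randn(1, 1, dstate, L)               # variable C, one group

    out = selective_scan_ref(u, delta, A, B, C, delta_softplus=True, cu_seqlens=cu_seqlens)
    # out: (1, dim, L); the scan state is re-initialized at positions 3 and 8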
@@ -354,4 +424,4 @@ def mamba_inner_ref( else: C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) - return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) + return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) \ No newline at end of file From f6bb7e26f8359d0bdd22e43c14d17b5e44b39600 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Fri, 26 Apr 2024 15:22:21 +0800 Subject: [PATCH 13/27] add example test to prove the mathematical equivalence of cu_seqlens for mamba block --- .../ops/test_mamba_cu_seqlens_equivalence.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 tests/ops/test_mamba_cu_seqlens_equivalence.py diff --git a/tests/ops/test_mamba_cu_seqlens_equivalence.py b/tests/ops/test_mamba_cu_seqlens_equivalence.py new file mode 100644 index 00000000..2a426f93 --- /dev/null +++ b/tests/ops/test_mamba_cu_seqlens_equivalence.py @@ -0,0 +1,91 @@ +import random +import torch + +from mamba_ssm.modules.mamba_simple import Mamba + + +''' +unpack function: convert packed_hidden_states (batch_size=1) to hidden_states +''' +def unpack(packed_hidden_states, cu_seqlens): + batch_size = cu_seqlens.shape[0] - 1 + seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + hidden_dim = packed_hidden_states.shape[2] + hidden_states = torch.zeros(batch_size, seq_len, hidden_dim, dtype=packed_hidden_states.dtype, device=packed_hidden_states.device) + for i in range(batch_size): + hidden_states[i, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[:, cu_seqlens[i] : cu_seqlens[i + 1], :] + return hidden_states + + +''' +pack function: convert hidden_states to packed_hidden_states (batch_size=1) +''' +def pack(hidden_states, cu_seqlens): + batch_size, seq_len, hidden_dim = hidden_states.shape + seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] + seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) + indices_3d = ( + torch.arange(seq_len, device=hidden_states.device) + .unsqueeze(0) + .unsqueeze(2) + .repeat(batch_size, 1, hidden_dim) + ) + mask_3d = indices_3d < seq_len_list_3d + packed_hidden_states = hidden_states[mask_3d].view(-1, hidden_dim) + return packed_hidden_states + + +''' +Generate random cu_seqlens for testing +''' +def generate_random_cu_seqlens(seq_len, batch_size): + if batch_size > 1: + ret = sorted(random.sample(range(1, seq_len), batch_size - 1)) + else: + ret = [] + cu_seqlens = [0] + ret + [seq_len] + assert batch_size == len(cu_seqlens) - 1 + return cu_seqlens + + +def main(): + # config tested with A100 + hidden_dim = 2048 + seq_len = 1024 + batch_size = 8 + device='cuda' + + # Generate random cu_seqlens for testing + cu_seqlens = generate_random_cu_seqlens(seq_len, batch_size) + cu_seqlens = torch.tensor(cu_seqlens, device=device) + print(f'Generate random cu_seqlens = {cu_seqlens.tolist()}') + + # Generate packed_hidden_states with random values for testing + # packed_hidden_states (batch_size=1) should be forwarded with cu_seqlens + hidden_states_list = [torch.randn(l, hidden_dim, device=device) for l in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()] + packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0) + # hidden_states should be forwarded without cu_seqlens + hidden_states = unpack(packed_hidden_states, cu_seqlens) + + # Check: sum of seq_len of item in hidden_states_list should be equal to seq_len of packed_hidden_states + assert sum([hs.shape[0] for hs 
in hidden_states_list]) == packed_hidden_states.shape[1] + # Check: max of seq_len of item in hidden_states_list should be equal to seq_len of hidden_states + assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1] + + # creat one simple mamba block + mamba = Mamba(hidden_dim).to(device) + + # reference output for forwardding hidden_states + out_ref = mamba(hidden_states) + out_ref = pack(out_ref, cu_seqlens).unsqueeze(0) + + # output for forwardding packed_hidden_states with cu_seqlens + out = mamba(packed_hidden_states, cu_seqlens) + + # Testing the max/mean diff + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + + +if __name__ == "__main__": + main() \ No newline at end of file From bffcd97869bbf05f9c4394d537d51c221c47a82b Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Fri, 26 Apr 2024 16:15:57 +0800 Subject: [PATCH 14/27] fix typos --- tests/ops/test_selective_scan_var_len.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ops/test_selective_scan_var_len.py b/tests/ops/test_selective_scan_var_len.py index 199d4610..37407b42 100644 --- a/tests/ops/test_selective_scan_var_len.py +++ b/tests/ops/test_selective_scan_var_len.py @@ -82,19 +82,19 @@ def test_selective_scan_variable_length(is_variable_B, is_variable_C, varBC_grou u_ref = u.detach().clone().requires_grad_() delta_ref = delta.detach().clone().requires_grad_() delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - start_indexes = torch.arange(0, seqlen, seqlen / seq_num, dtype=torch.int64).cuda() + cu_seqlens = torch.arange(0, seqlen, seqlen / seq_num, dtype=torch.int64).cuda() out, *rest = selective_scan_fn( u, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus, - return_last_state=return_last_state, cu_seqlens=start_indexes + return_last_state=return_last_state, cu_seqlens=cu_seqlens ) if return_last_state: state = rest[0] out_ref, *rest = selective_scan_ref( u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, delta_bias=delta_bias_ref, delta_softplus=delta_softplus, - return_last_state=return_last_state, cu_seqlens=start_indexes + return_last_state=return_last_state, cu_seqlens=cu_seqlens ) if return_last_state: state_ref = rest[0] From e3cab98f71ee3e8d4b703b244750da20c6183dee Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Fri, 26 Apr 2024 16:36:45 +0800 Subject: [PATCH 15/27] add cu_seqlens support for MixerModel --- mamba_ssm/models/mixer_seq_simple.py | 8 ++++---- mamba_ssm/modules/mamba_simple.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index cd224738..9a36de7e 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -148,12 +148,12 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) for i, layer in enumerate(self.layers) } - def forward(self, input_ids, inference_params=None): + def forward(self, input_ids, cu_seqlens=None, inference_params=None): hidden_states = self.embedding(input_ids) residual = None for layer in self.layers: hidden_states, residual = layer( - hidden_states, residual, inference_params=inference_params + hidden_states, residual, cu_seqlens=cu_seqlens, inference_params=inference_params ) if not self.fused_add_norm: residual = (hidden_states + residual) if residual is not None else hidden_states @@ -226,12 +226,12 @@ def 
tie_weights(self): def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): + def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, cu_seqlens=None): """ "position_ids" is just to be compatible with Transformer generation. We don't use it. num_last_tokens: if > 0, only return the logits for the last n tokens """ - hidden_states = self.backbone(input_ids, inference_params=inference_params) + hidden_states = self.backbone(input_ids, inference_params=inference_params, cu_seqlens=cu_seqlens) if num_last_tokens > 0: hidden_states = hidden_states[:, -num_last_tokens:] lm_logits = self.lm_head(hidden_states) diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index c56e587c..6fb6327f 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -347,7 +347,7 @@ def __init__( ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None + self, hidden_states: Tensor, residual: Optional[Tensor] = None, cu_seqlens=None, inference_params=None ): r"""Pass the input through the encoder layer. @@ -371,7 +371,7 @@ def forward( residual_in_fp32=self.residual_in_fp32, eps=self.norm.eps, ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) + hidden_states = self.mixer(hidden_states, cu_seqlens=cu_seqlens, inference_params=inference_params) return hidden_states, residual def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): From 2f01ededae661c254f541d0a0bbc2fab6ed3279d Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Tue, 30 Apr 2024 11:22:53 +0800 Subject: [PATCH 16/27] code refine for tests --- mamba_ssm/modules/mamba_simple.py | 2 +- mamba_ssm/ops/selective_scan_interface.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 6fb6327f..fc3029a0 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -159,7 +159,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): delta_bias=self.dt_proj.bias.float(), delta_softplus=True, cu_seqlens=cu_seqlens, - d_conv=self.d_conv, + d_conv=torch.tensor([self.d_conv], device=cu_seqlens.device), ) else: x, z = xz.chunk(2, dim=1) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index af28c417..d641380e 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -257,7 +257,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, torch.tensor([d_conv])) + A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @@ -267,7 +267,6 @@ def backward(ctx, dout): assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 
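With the plumbing from this patch, cu_seqlens can be threaded from the language-model wrapper down to every Mamba mixer, so several prompts can share one batch row. A sketch of such a packed forward pass, assuming the standard MambaConfig dataclass from mamba_ssm.models.config_mamba (not shown in this series) and illustrative model sizes:

    import torch
    from mamba_ssm.models.config_mamba import MambaConfig
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    device = "cuda"
    config = MambaConfig(d_model=256, n_layer=2, vocab_size=1000)    # illustrative config
    model = MambaLMHeadModel(config, device=device)

    # two prompts of length 6 and 10 packed into a single row (batch = 1)
    input_ids = torch.randint(0, 1000, (1, 16), device=device)
    cu_seqlens = torch.tensor([0, 6, 16], device=device)

    out = model(input_ids, cu_seqlens=cu_seqlens)
    # the LM head output covers all 16 packed positions of the single row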
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) = ctx.saved_tensors - d_conv = d_conv.item() L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) From f0a650893268c94c2b903bd8870586ec7c84bf5e Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Tue, 30 Apr 2024 16:27:37 +0800 Subject: [PATCH 17/27] refine code for tests --- mamba_ssm/modules/mamba_simple.py | 2 +- .../ops/test_mamba_cu_seqlens_equivalence.py | 10 ++++++++ tests/ops/test_selective_scan_var_len.py | 24 +++++++++++++------ 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index fc3029a0..fdfac1e1 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -159,7 +159,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): delta_bias=self.dt_proj.bias.float(), delta_softplus=True, cu_seqlens=cu_seqlens, - d_conv=torch.tensor([self.d_conv], device=cu_seqlens.device), + d_conv=torch.tensor(self.d_conv) ) else: x, z = xz.chunk(2, dim=1) diff --git a/tests/ops/test_mamba_cu_seqlens_equivalence.py b/tests/ops/test_mamba_cu_seqlens_equivalence.py index 2a426f93..20106730 100644 --- a/tests/ops/test_mamba_cu_seqlens_equivalence.py +++ b/tests/ops/test_mamba_cu_seqlens_equivalence.py @@ -55,6 +55,15 @@ def main(): batch_size = 8 device='cuda' + itype = torch.float32 + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # Generate random cu_seqlens for testing cu_seqlens = generate_random_cu_seqlens(seq_len, batch_size) cu_seqlens = torch.tensor(cu_seqlens, device=device) @@ -85,6 +94,7 @@ def main(): # Testing the max/mean diff print(f'Output max diff: {(out - out_ref).abs().max().item()}') print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) if __name__ == "__main__": diff --git a/tests/ops/test_selective_scan_var_len.py b/tests/ops/test_selective_scan_var_len.py index 37407b42..9504c2c2 100644 --- a/tests/ops/test_selective_scan_var_len.py +++ b/tests/ops/test_selective_scan_var_len.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Tri Dao. 
import math +import random import torch import torch.nn.functional as F @@ -10,6 +11,18 @@ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref +''' +Generate random cu_seqlens for testing +''' +def generate_random_cu_seqlens(seq_len): + batch_size = random.randint(1,seq_len) + if batch_size > 1: + ret = sorted(random.sample(range(1, seq_len), batch_size - 1)) + else: + ret = [] + cu_seqlens = [0] + ret + [seq_len] + assert batch_size == len(cu_seqlens) - 1 + return [0] + ret + [seq_len] @pytest.mark.parametrize('wtype', [torch.float32]) @pytest.mark.parametrize('itype', [torch.float32]) @@ -22,10 +35,8 @@ @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) -@pytest.mark.parametrize("seq_num", [1, 2, 3, 4, 5, 6, 7]) def test_selective_scan_variable_length(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, - delta_softplus, return_last_state, seqlen, itype, wtype, seq_num): - is_variable_B = True + delta_softplus, return_last_state, seqlen, itype, wtype): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' @@ -39,8 +50,8 @@ def test_selective_scan_variable_length(is_variable_B, is_variable_C, varBC_grou # set seed torch.random.manual_seed(0) batch_size = 1 - dim = 4 - dstate = 1 + dim = 768 + dstate = 8 is_complex = wtype == torch.complex64 A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() if not is_variable_B: @@ -73,7 +84,6 @@ def test_selective_scan_variable_length(is_variable_B, is_variable_C, varBC_grou delta_bias = None u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() - A_ref = A.detach().clone().requires_grad_() B_ref = B.detach().clone().requires_grad_() C_ref = C.detach().clone().requires_grad_() @@ -82,7 +92,7 @@ def test_selective_scan_variable_length(is_variable_B, is_variable_C, varBC_grou u_ref = u.detach().clone().requires_grad_() delta_ref = delta.detach().clone().requires_grad_() delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - cu_seqlens = torch.arange(0, seqlen, seqlen / seq_num, dtype=torch.int64).cuda() + cu_seqlens = torch.tensor(generate_random_cu_seqlens(seqlen)).cuda() out, *rest = selective_scan_fn( u, delta, A, B, C, D, z=z, From 623d246f95b922178afbebbc26a0d059d5ed7c75 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Tue, 30 Apr 2024 16:33:30 +0800 Subject: [PATCH 18/27] update API notes --- mamba_ssm/modules/mamba_simple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index fdfac1e1..0862763c 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -119,7 +119,7 @@ def __init__( def forward(self, hidden_states, cu_seqlens=None, inference_params=None): """ hidden_states: (B, L, D) - cu_seqlens: one-dimensional tensor like flash-attn varlen API, only used for variable-length sequences and packing variable-length sequences into one, a.k.a., batch_size B=1 + cu_seqlens: one-dimensional tensor representing cumulative start indexes of packed sequence, a.k.a., B=1 Returns: same shape as hidden_states """ batch, seqlen, dim = hidden_states.shape From ef3f760bc409c02c2e61bb57786da285f750af9e Mon Sep 17 
00:00:00 2001 From: zigzagcai Date: Tue, 30 Apr 2024 17:08:39 +0800 Subject: [PATCH 19/27] update test code --- tests/ops/test_mamba_cu_seqlens_equivalence.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/ops/test_mamba_cu_seqlens_equivalence.py b/tests/ops/test_mamba_cu_seqlens_equivalence.py index 20106730..01d066ef 100644 --- a/tests/ops/test_mamba_cu_seqlens_equivalence.py +++ b/tests/ops/test_mamba_cu_seqlens_equivalence.py @@ -82,7 +82,13 @@ def main(): assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1] # creat one simple mamba block - mamba = Mamba(hidden_dim).to(device) + mamba = Mamba( + # This module uses roughly 3 * expand * d_model^2 parameters + d_model=hidden_dim, # Model dimension d_model + d_state=16, # SSM state expansion factor + d_conv=4, # Local convolution width + expand=2, # Block expansion factor + ).to(device) # reference output for forwardding hidden_states out_ref = mamba(hidden_states) From 2d27ccc555ae958444cdf655c395440501eb10f6 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Thu, 6 Jun 2024 14:45:43 +0800 Subject: [PATCH 20/27] fix conflicts with latest main branch --- mamba_ssm/models/mixer_seq_simple.py | 4 +- mamba_ssm/modules/block.py | 6 +-- mamba_ssm/modules/mamba_simple.py | 56 ++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index 98a05b58..287b6535 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -16,6 +16,7 @@ from mamba_ssm.modules.mamba2 import Mamba2 from mamba_ssm.modules.mha import MHA from mamba_ssm.modules.mlp import GatedMLP +from mamba_ssm.modules.mamba_simple import Block as Block_Mamba1 from mamba_ssm.modules.block import Block from mamba_ssm.utils.generation import GenerationMixin from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf @@ -70,7 +71,8 @@ def create_block( mlp_cls = partial( GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs ) - block = Block( + block_cls = Block if ssm_layer == "Mamba2" else Block_Mamba1 + block = block_cls( d_model, mixer_cls, mlp_cls, diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py index ffd03034..5fdaff05 100644 --- a/mamba_ssm/modules/block.py +++ b/mamba_ssm/modules/block.py @@ -40,7 +40,7 @@ def __init__( ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, cu_seqlens=None, inference_params=None, **mixer_kwargs + self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs ): r"""Pass the input through the encoder layer. 
@@ -64,7 +64,7 @@ def forward( eps=self.norm.eps, is_rms_norm=isinstance(self.norm, RMSNorm) ) - hidden_states = self.mixer(hidden_states, cu_seqlens=cu_seqlens, inference_params=inference_params, **mixer_kwargs) + hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs) if self.mlp is not None: if not self.fused_add_norm: @@ -88,4 +88,4 @@ def forward( return hidden_states, residual def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 04b1ba73..7bf976c4 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -317,3 +317,59 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states conv_state.zero_() ssm_state.zero_() return conv_state, ssm_state + +class Block(nn.Module): + def __init__( + self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False + ): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.mixer = mixer_cls(dim) + self.norm = norm_cls(dim) + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + + def forward( + self, hidden_states: Tensor, residual: Optional[Tensor] = None, cu_seqlens=None, inference_params=None + ): + r"""Pass the input through the encoder layer. + Args: + hidden_states: the sequence to the encoder layer (required). 
+ residual: hidden_states = Mixer(LN(residual)) + """ + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + hidden_states = self.mixer(hidden_states, cu_seqlens=cu_seqlens, inference_params=inference_params) + return hidden_states, residual + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file From 596943cd21c01be44a03313f56572338f3616f4c Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Tue, 16 Jul 2024 23:28:17 +0800 Subject: [PATCH 21/27] fix unittest for test_selective_state_update_with_heads --- tests/ops/triton/test_selective_state_update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ops/triton/test_selective_state_update.py b/tests/ops/triton/test_selective_state_update.py index 696b2c77..e81807ae 100644 --- a/tests/ops/triton/test_selective_state_update.py +++ b/tests/ops/triton/test_selective_state_update.py @@ -6,7 +6,7 @@ import torch.nn.functional as F import pytest -from einops import rearrange +from einops import rearrange, repeat from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref From b69b957e6ba2910ca673c810d8b47ae080feccf6 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Fri, 19 Jul 2024 20:46:34 +0800 Subject: [PATCH 22/27] migrate to tridao's native varlen causal_conv1d kernel for speedup --- mamba_ssm/models/mixer_seq_simple.py | 12 +- mamba_ssm/modules/block.py | 2 +- mamba_ssm/modules/mamba_simple.py | 105 +++---------- mamba_ssm/ops/selective_scan_interface.py | 147 +++++++++--------- .../ops/test_mamba_cu_seqlens_equivalence.py | 28 +++- 5 files changed, 122 insertions(+), 172 deletions(-) diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index 06ecb183..fae2257a 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -16,7 +16,6 @@ from mamba_ssm.modules.mamba2 import Mamba2 from mamba_ssm.modules.mha import MHA from mamba_ssm.modules.mlp import GatedMLP -from mamba_ssm.modules.mamba_simple import Block as Block_Mamba1 from mamba_ssm.modules.block import Block from mamba_ssm.utils.generation import GenerationMixin from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf @@ -71,8 +70,7 @@ def create_block( mlp_cls = partial( GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs ) - block_cls = Block if ssm_layer == "Mamba2" else Block_Mamba1 - block = block_cls( + block = Block( d_model, mixer_cls, mlp_cls, @@ -189,12 +187,12 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) for i, layer in enumerate(self.layers) } - def forward(self, input_ids, cu_seqlens=None, inference_params=None, **mixer_kwargs): + def forward(self, input_ids, inference_params=None, **mixer_kwargs): hidden_states = self.embedding(input_ids) residual = None for layer in self.layers: hidden_states, residual = layer( - 
hidden_states, residual, cu_seqlens=cu_seqlens, inference_params=inference_params, **mixer_kwargs + hidden_states, residual, inference_params=inference_params, **mixer_kwargs ) if not self.fused_add_norm: residual = (hidden_states + residual) if residual is not None else hidden_states @@ -273,12 +271,12 @@ def tie_weights(self): def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, cu_seqlens=None, **mixer_kwargs): + def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs): """ "position_ids" is just to be compatible with Transformer generation. We don't use it. num_last_tokens: if > 0, only return the logits for the last n tokens """ - hidden_states = self.backbone(input_ids, cu_seqlens=cu_seqlens, inference_params=inference_params, **mixer_kwargs) + hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs) if num_last_tokens > 0: hidden_states = hidden_states[:, -num_last_tokens:] lm_logits = self.lm_head(hidden_states) diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py index 5fdaff05..b0ed44e1 100644 --- a/mamba_ssm/modules/block.py +++ b/mamba_ssm/modules/block.py @@ -88,4 +88,4 @@ def forward( return hidden_states, residual def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 7bf976c4..d27ec31d 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -10,7 +10,7 @@ from einops import rearrange, repeat -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, selective_scan_ref +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update @@ -119,10 +119,12 @@ def __init__( def forward(self, hidden_states, cu_seqlens=None, inference_params=None): """ hidden_states: (B, L, D) - cu_seqlens: one-dimensional tensor representing cumulative start indexes of packed sequence, a.k.a., B=1 + cu_seqlens: (Optional) cumulative sum of the sequence lengths, starting from 0 and end with L, and must already be sorted. 
Returns: same shape as hidden_states """ batch, seqlen, dim = hidden_states.shape + if cu_seqlens is not None: + assert batch == 1 and cu_seqlens.ndimension() == 1, "varlen mamba1 is only supported with B=1" conv_state, ssm_state = None, None if inference_params is not None: @@ -158,46 +160,40 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): self.D.float(), delta_bias=self.dt_proj.bias.float(), delta_softplus=True, - cu_seqlens=cu_seqlens, - d_conv=torch.tensor(self.d_conv) + cu_seqlens=cu_seqlens ) else: x, z = xz.chunk(2, dim=1) - - # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences - if cu_seqlens is not None: - padded_x = x - count = 0 - for idx in cu_seqlens[1:-1].tolist(): - padded_idx = idx + count*(self.d_conv - 1) - padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) - count = count + 1 - x = padded_x - # Compute short convolution if conv_state is not None: # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) + if cu_seqlens is not None: + # naive pure python implementation of varlen causal_conv1d + for i, s in enumerate(cu_seqlens[1:-1]): + x = torch.cat((x[..., :s + i*(self.d_conv - 1)], torch.zeros_like(x[..., :(self.d_conv - 1)]), x[..., s + i*(self.d_conv - 1):]), dim=2) + mask = torch.cat([torch.cat((torch.full((s,), True, dtype=torch.bool, device=x.device), + torch.full((self.d_conv - 1,), False, dtype=torch.bool, device=x.device)), dim=0) + for s in (cu_seqlens[1:] - cu_seqlens[:-1])], dim=0) + x = self.act(self.conv1d(x)[:, :, mask]) + else: + x = self.act(self.conv1d(x)[..., :seqlen]) else: assert self.activation in ["silu", "swish"] + if cu_seqlens is not None: + seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) + else: + seq_idx = None x = causal_conv1d_fn( - x=x, + x=x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, + seq_idx=seq_idx, activation=self.activation, ) - - # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences - if cu_seqlens is not None: - mask = [] - for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): - mask.extend([True] * seq_len) - mask.extend([False] * (self.d_conv - 1)) - mask = mask[:-(self.d_conv - 1)] - x = x[:, :, mask] # We're careful here about the layout, to avoid extra transposes. 
# We want dt to have d as the slowest moving dimension @@ -208,7 +204,6 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - assert self.activation in ["silu", "swish"] y = selective_scan_fn( x, @@ -317,59 +312,3 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states conv_state.zero_() ssm_state.zero_() return conv_state, ssm_state - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, cu_seqlens=None, inference_params=None - ): - r"""Pass the input through the encoder layer. - Args: - hidden_states: the sequence to the encoder layer (required). 
- residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, cu_seqlens=cu_seqlens, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index d641380e..fa43b38b 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -169,12 +169,19 @@ class MambaInnerFn(torch.autograd.Function): def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None, checkpoint_lvl=1): + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." assert checkpoint_lvl in [0, 1] + + if cu_seqlens is not None: + seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) + else: + seq_idx = None + L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) @@ -188,28 +195,18 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) - - # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences - if cu_seqlens is not None: - padded_x = x - count = 0 - for idx in cu_seqlens[1:-1].tolist(): - padded_idx = idx + count*(d_conv - 1) - padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) - count = count + 1 - x = padded_x - - conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) - - # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences - if cu_seqlens is not None: - mask = [] - for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): - mask.extend([True] * seq_len) - mask.extend([False] * (d_conv - 1)) - mask = mask[:-(d_conv - 1)] - conv1d_out = conv1d_out[:, :, mask] + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + conv1d_weight, + conv1d_bias, + seq_idx, + None, + None, + True + ) + if 
conv1d_out.stride(-1) != 1: + conv1d_out = conv1d_out.contiguous() # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension @@ -248,7 +245,16 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh if D is not None: D = D.contiguous() out, scan_intermediates, out_z = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seqlens + conv1d_out, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + cu_seqlens ) ctx.delta_softplus = delta_softplus ctx.out_proj_bias_is_None = out_proj_bias is None @@ -257,7 +263,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) + A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, seq_idx) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @@ -266,40 +272,27 @@ def backward(ctx, dout): # dout: (batch, seqlen, dim) assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, - conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, d_conv) = ctx.saved_tensors + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens, seq_idx) = ctx.saved_tensors L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) if dout.stride(-1) != 1: dout = dout.contiguous() - - x_bak = x if ctx.checkpoint_lvl == 1: - # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences - if cu_seqlens is not None: - padded_x = x - count = 0 - for idx in cu_seqlens[1:-1].tolist(): - padded_idx = idx + count*(d_conv - 1) - padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) - count = count + 1 - x = padded_x - - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) - - # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences - if cu_seqlens is not None: - mask = [] - for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): - mask.extend([True] * seq_len) - mask.extend([False] * (d_conv - 1)) - mask = mask[:-(d_conv - 1)] - conv1d_out = conv1d_out[:, :, mask] - + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + conv1d_weight, + conv1d_bias, + seq_idx, + None, + None, + True + ) + if conv1d_out.stride(-1) != 1: + conv1d_out = conv1d_out.contiguous() delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) - x = x_bak # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). @@ -345,33 +338,42 @@ def backward(ctx, dout): # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). 
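The native varlen path drops the padding-and-masking entirely: seq_idx assigns every position of the packed row the index of the sequence it belongs to, and the varlen causal_conv1d kernel uses it to restart the receptive field at each sequence start. The torch.cat construction used above can equivalently be written with repeat_interleave, sketched here together with the layout trick; the double transpose keeps the (batch, dim, len) shape but makes the dim axis contiguous per token, which is the layout this code hands to the kernel (values below are illustrative):

    import torch

    cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).long()            # (3, 5, 2)

    # same seq_idx as the torch.cat(...) above: one int32 sequence id per token
    seq_idx = torch.repeat_interleave(
        torch.arange(len(lengths), dtype=torch.int32), lengths
    ).unsqueeze(0)                                                  # [[0,0,0,1,1,1,1,1,2,2]]

    x = torch.randn(1, 4, 10)
    # equivalent of x.transpose(1,2).contiguous().transpose(1,2): same shape,
    # but the dim axis now has stride 1 (channel-last memory layout)
    x_channel_last = x.transpose(1, 2).contiguous().transpose(1, 2)
    assert x_channel_last.shape == x.shape and x_channel_last.stride() == (40, 1, 4)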
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( - x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + conv1d_weight, + conv1d_bias, + dconv1d_out, + seq_idx, + None, + None, + dx.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else dx, + False, + True ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") - return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, + return (torch.cat((dx, dz), dim=1) if cu_seqlens is not None else dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None, None, None) + dB_proj_bias, dC_proj_bias, None, None) def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None + C_proj_bias=None, delta_softplus=True, cu_seqlens=None ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens, d_conv) + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens) def mamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None, d_conv=None + C_proj_bias=None, delta_softplus=True, cu_seqlens=None ): assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." L = xz.shape[-1] @@ -379,26 +381,19 @@ def mamba_inner_ref( d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) - # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences if cu_seqlens is not None: - padded_x = x - count = 0 - for idx in cu_seqlens[1:-1].tolist(): - padded_idx = idx + count*(d_conv - 1) - padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2) - count = count + 1 - x = padded_x - - x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu") - - # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences - if cu_seqlens is not None: - mask = [] - for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(): - mask.extend([True] * seq_len) - mask.extend([False] * (d_conv - 1)) - mask = mask[:-(d_conv - 1)] - x = x[:, :, mask] + seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) + else: + seq_idx = None + + x = causal_conv1d_fn( + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + rearrange(conv1d_weight, "d 1 w -> d w"), + conv1d_bias, + seq_idx=seq_idx, + activation="silu" + ) # We're being very careful here about the layout, to avoid extra transposes. 
# We want delta to have d as the slowest moving dimension @@ -422,5 +417,5 @@ def mamba_inner_ref( C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() - y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) + y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True, cu_seqlens=cu_seqlens) return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) \ No newline at end of file diff --git a/tests/ops/test_mamba_cu_seqlens_equivalence.py b/tests/ops/test_mamba_cu_seqlens_equivalence.py index 01d066ef..1937421f 100644 --- a/tests/ops/test_mamba_cu_seqlens_equivalence.py +++ b/tests/ops/test_mamba_cu_seqlens_equivalence.py @@ -1,3 +1,4 @@ +import copy import random import torch @@ -82,7 +83,7 @@ def main(): assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1] # creat one simple mamba block - mamba = Mamba( + mamba_ref = Mamba( # This module uses roughly 3 * expand * d_model^2 parameters d_model=hidden_dim, # Model dimension d_model d_state=16, # SSM state expansion factor @@ -91,17 +92,34 @@ def main(): ).to(device) # reference output for forwardding hidden_states - out_ref = mamba(hidden_states) - out_ref = pack(out_ref, cu_seqlens).unsqueeze(0) + out_ref_original = mamba_ref(hidden_states) + out_ref = pack(out_ref_original, cu_seqlens).unsqueeze(0) # output for forwardding packed_hidden_states with cu_seqlens + mamba = copy.deepcopy(mamba_ref) out = mamba(packed_hidden_states, cu_seqlens) # Testing the max/mean diff - print(f'Output max diff: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + print(f'Output max diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().mean().item()}') assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + # bwd for mamba w/ cu_seqlens + g = torch.randn_like(out) + out.backward(g) + mamba_grad = {name: param.grad.clone() for name, param in mamba.named_parameters()} + + # bwd for mamba wo/ cu_seqlens + g_ref = unpack(g, cu_seqlens) + out_ref_original.backward(g_ref) + mamba_ref_grad = {name: param.grad.clone() for name, param in mamba_ref.named_parameters()} + + # check bwd pass + assert set(mamba_grad.keys()) == set(mamba_ref_grad.keys()) + for name in mamba_ref_grad: + print(f'Output max diff for {name} in varlen_mamba bwd pass: {( - mamba_ref_grad[name]).abs().max().item()}') + print(f'Output mean diff for {name} in varlen_mamba bwd pass: {(mamba_grad[name] - mamba_ref_grad[name]).abs().mean().item()}') + assert torch.allclose(mamba_grad[name], mamba_ref_grad[name], rtol=rtol, atol=atol) if __name__ == "__main__": main() \ No newline at end of file From 909f9706e547f3e06e134db7513e2b86bcd6f205 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Tue, 23 Jul 2024 18:30:10 +0800 Subject: [PATCH 23/27] typo fix --- tests/ops/test_mamba_cu_seqlens_equivalence.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/ops/test_mamba_cu_seqlens_equivalence.py b/tests/ops/test_mamba_cu_seqlens_equivalence.py index 1937421f..772b3b6f 100644 --- a/tests/ops/test_mamba_cu_seqlens_equivalence.py +++ b/tests/ops/test_mamba_cu_seqlens_equivalence.py @@ -100,8 +100,8 @@ def main(): out = mamba(packed_hidden_states, cu_seqlens) # Testing the max/mean diff - 
print(f'Output max diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().mean().item()}') + print(f'max diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().max().item()}') + print(f'mean diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().mean().item()}') assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) # bwd for mamba w/ cu_seqlens @@ -117,9 +117,9 @@ def main(): # check bwd pass assert set(mamba_grad.keys()) == set(mamba_ref_grad.keys()) for name in mamba_ref_grad: - print(f'Output max diff for {name} in varlen_mamba bwd pass: {( - mamba_ref_grad[name]).abs().max().item()}') - print(f'Output mean diff for {name} in varlen_mamba bwd pass: {(mamba_grad[name] - mamba_ref_grad[name]).abs().mean().item()}') + print(f'max diff for {name} in varlen_mamba bwd pass: {(mamba_grad[name] - mamba_ref_grad[name]).abs().max().item()}') + print(f'mean diff for {name} in varlen_mamba bwd pass: {(mamba_grad[name] - mamba_ref_grad[name]).abs().mean().item()}') assert torch.allclose(mamba_grad[name], mamba_ref_grad[name], rtol=rtol, atol=atol) if __name__ == "__main__": - main() \ No newline at end of file + main() From 8174c453e27c7f450631b95f44ea7f6aa026618d Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 5 Aug 2024 18:13:57 +0800 Subject: [PATCH 24/27] use seq_idx if provided, or compute it by cu_seqlens --- mamba_ssm/modules/mamba_simple.py | 12 ++++++++++-- mamba_ssm/ops/selective_scan_interface.py | 24 ++++++----------------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index d27ec31d..e00d8e02 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -116,15 +116,22 @@ def __init__( self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - def forward(self, hidden_states, cu_seqlens=None, inference_params=None): + def forward(self, hidden_states, cu_seqlens=None, seq_idx=None, inference_params=None): """ hidden_states: (B, L, D) cu_seqlens: (Optional) cumulative sum of the sequence lengths, starting from 0 and end with L, and must already be sorted. 
Returns: same shape as hidden_states """ batch, seqlen, dim = hidden_states.shape + if cu_seqlens is not None: assert batch == 1 and cu_seqlens.ndimension() == 1, "varlen mamba1 is only supported with B=1" + # compute seq_idx if not provided + if seq_idx is None: + seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) + else: + seq_idx = None conv_state, ssm_state = None, None if inference_params is not None: @@ -160,7 +167,8 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None): self.D.float(), delta_bias=self.dt_proj.bias.float(), delta_softplus=True, - cu_seqlens=cu_seqlens + cu_seqlens=cu_seqlens, + seq_idx=seq_idx, ) else: x, z = xz.chunk(2, dim=1) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index fa43b38b..5d654d83 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -169,19 +169,13 @@ class MambaInnerFn(torch.autograd.Function): def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None, checkpoint_lvl=1): + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, seq_idx=None, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." assert checkpoint_lvl in [0, 1] - - if cu_seqlens is not None: - seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) - for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) - else: - seq_idx = None - + L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) @@ -355,25 +349,25 @@ def backward(ctx, dout): dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None, None) + dB_proj_bias, dC_proj_bias, None, None, None) def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, seq_idx=None, ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens) + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens, seq_idx) def mamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, cu_seqlens=None + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, seq_idx=None, ): assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." 
L = xz.shape[-1] @@ -381,12 +375,6 @@ def mamba_inner_ref( d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) - if cu_seqlens is not None: - seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) - for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) - else: - seq_idx = None - x = causal_conv1d_fn( x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, rearrange(conv1d_weight, "d 1 w -> d w"), From 59be631bef55b653d2e9f3024d1ab10847702e2c Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 5 Aug 2024 18:16:35 +0800 Subject: [PATCH 25/27] use seq_idx if provided, or compute it by cu_seqlens --- mamba_ssm/modules/mamba_simple.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index e00d8e02..ab7a65f9 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -190,11 +190,6 @@ def forward(self, hidden_states, cu_seqlens=None, seq_idx=None, inference_params x = self.act(self.conv1d(x)[..., :seqlen]) else: assert self.activation in ["silu", "swish"] - if cu_seqlens is not None: - seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) - for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) - else: - seq_idx = None x = causal_conv1d_fn( x=x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), From 210b6f6ffba7bd838902a2877a1dfdbf26fa6641 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Wed, 7 Aug 2024 15:04:10 +0800 Subject: [PATCH 26/27] mv cu_seqlens in ssm kernel to smem --- .../selective_scan_bwd_kernel.cuh | 34 ++++++++++++------ .../selective_scan_fwd_kernel.cuh | 36 +++++++++++++------ 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index 75f50c11..a1797558 100755 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -142,7 +142,15 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; float dD_val = 0; float ddelta_bias_val = 0; - long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride; + + // Load cu_seqlens into shared memory + const int cu_seqlens_size = params.cu_seqlens_size; + long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr); + __shared__ long smem_cu_seqlens[1024]; // Adjust size as needed + for (int i = threadIdx.x; i < cu_seqlens_size; i += blockDim.x) { + smem_cu_seqlens[i] = cu_seqlens[i]; + } + __syncthreads(); constexpr int kChunkSize = kNThreads * kNItems; u += (params.n_chunks - 1) * kChunkSize; @@ -255,15 +263,17 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { // Reset A bar for cumulative sequences (Real) int left = 1; - int right = params.cu_seqlens_size - 2; + int right = cu_seqlens_size - 2; + int idx = threadIdx.x * kNItems + i + chunk * kChunkSize; while (left <= right) { - if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { + int mid = (left + right) >> 1; + if (smem_cu_seqlens[mid] == idx) { delta_a_exp = 0.f; break; - } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { - left = ((left + right) >> 1) + 1; + } 
else if (smem_cu_seqlens[mid] < idx) { + left = mid + 1; } else { - right = ((left + right) >> 1) - 1; + right = mid - 1; } } @@ -358,16 +368,18 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { // Reset A bar for cumulative sequences (Complex) int left = 1; - int right = params.cu_seqlens_size - 2; + int right = cu_seqlens_size - 2; + int idx = threadIdx.x * kNItems + i + chunk * kChunkSize; while (left <= right) { - if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { + int mid = (left + right) >> 1; + if (smem_cu_seqlens[mid] == idx) { delta_a_exp.real_ = 0.f; delta_a_exp.imag_ = 0.f; break; - } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { - left = ((left + right) >> 1) + 1; + } else if (smem_cu_seqlens[mid] < idx) { + left = mid + 1; } else { - right = ((left + right) >> 1) - 1; + right = mid - 1; } } diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index c8f0b82a..e1d1577c 100755 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -112,7 +112,16 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; - long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride; + + // Load cu_seqlens into shared memory + const int cu_seqlens_size = params.cu_seqlens_size; + long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr); + __shared__ long smem_cu_seqlens[1024]; // Adjust size as needed + for (int i = threadIdx.x; i < cu_seqlens_size; i += blockDim.x) { + smem_cu_seqlens[i] = cu_seqlens[i]; + } + __syncthreads(); + float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { @@ -224,15 +233,17 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // Reset A bar for cumulative sequences (Real) int left = 1; - int right = params.cu_seqlens_size - 2; + int right = cu_seqlens_size - 2; + int idx = threadIdx.x * kNItems + i + chunk * kChunkSize; while (left <= right) { - if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { + int mid = (left + right) >> 1; + if (smem_cu_seqlens[mid] == idx) { thread_data[i].x = 0.f; break; - } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { - left = ((left + right) >> 1) + 1; + } else if (smem_cu_seqlens[mid] < idx) { + left = mid + 1; } else { - right = ((left + right) >> 1) - 1; + right = mid - 1; } } @@ -249,19 +260,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // Reset A bar for cumulative sequences (Complex) int left = 1; - int right = params.cu_seqlens_size - 2; + int right = cu_seqlens_size - 2; + int idx = threadIdx.x * kNItems + i + chunk * kChunkSize; while (left <= right) { - if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { + int mid = (left + right) >> 1; + if (smem_cu_seqlens[mid] == idx) { thread_data[i].x = 0.f; thread_data[i].y = 0.f; break; - } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { - left = ((left + right) >> 1) + 1; + } else if (smem_cu_seqlens[mid] < idx) { + left = mid + 1; } else { - 
right = ((left + right) >> 1) - 1; + right = mid - 1; } } + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); From cda4b5aa243d27b42db90eaa61d448074a46ffe1 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Wed, 7 Aug 2024 16:30:25 +0800 Subject: [PATCH 27/27] remove smem implementation because const vals and bi-search is enough --- .../selective_scan_bwd_kernel.cuh | 16 +++++----------- .../selective_scan_fwd_kernel.cuh | 17 +++++------------ 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index a1797558..d2bce9ea 100755 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -143,14 +143,8 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { float dD_val = 0; float ddelta_bias_val = 0; - // Load cu_seqlens into shared memory const int cu_seqlens_size = params.cu_seqlens_size; - long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr); - __shared__ long smem_cu_seqlens[1024]; // Adjust size as needed - for (int i = threadIdx.x; i < cu_seqlens_size; i += blockDim.x) { - smem_cu_seqlens[i] = cu_seqlens[i]; - } - __syncthreads(); + const long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr); constexpr int kChunkSize = kNThreads * kNItems; u += (params.n_chunks - 1) * kChunkSize; @@ -267,10 +261,10 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { int idx = threadIdx.x * kNItems + i + chunk * kChunkSize; while (left <= right) { int mid = (left + right) >> 1; - if (smem_cu_seqlens[mid] == idx) { + if (cu_seqlens[mid] == idx) { delta_a_exp = 0.f; break; - } else if (smem_cu_seqlens[mid] < idx) { + } else if (cu_seqlens[mid] < idx) { left = mid + 1; } else { right = mid - 1; @@ -372,11 +366,11 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { int idx = threadIdx.x * kNItems + i + chunk * kChunkSize; while (left <= right) { int mid = (left + right) >> 1; - if (smem_cu_seqlens[mid] == idx) { + if (cu_seqlens[mid] == idx) { delta_a_exp.real_ = 0.f; delta_a_exp.imag_ = 0.f; break; - } else if (smem_cu_seqlens[mid] < idx) { + } else if (cu_seqlens[mid] < idx) { left = mid + 1; } else { right = mid - 1; diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index e1d1577c..fe083fb4 100755 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -113,15 +113,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; - // Load cu_seqlens into shared memory const int cu_seqlens_size = params.cu_seqlens_size; - long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr); - __shared__ long smem_cu_seqlens[1024]; // Adjust size as needed - for (int i = threadIdx.x; i < cu_seqlens_size; i += blockDim.x) { - smem_cu_seqlens[i] = cu_seqlens[i]; - } - __syncthreads(); - + const long *cu_seqlens = reinterpret_cast(params.cu_seqlens_ptr); float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { @@ -237,10 +230,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { int idx = threadIdx.x * kNItems + i + chunk * 
kChunkSize; while (left <= right) { int mid = (left + right) >> 1; - if (smem_cu_seqlens[mid] == idx) { + if (cu_seqlens[mid] == idx) { thread_data[i].x = 0.f; break; - } else if (smem_cu_seqlens[mid] < idx) { + } else if (cu_seqlens[mid] < idx) { left = mid + 1; } else { right = mid - 1; @@ -264,11 +257,11 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { int idx = threadIdx.x * kNItems + i + chunk * kChunkSize; while (left <= right) { int mid = (left + right) >> 1; - if (smem_cu_seqlens[mid] == idx) { + if (cu_seqlens[mid] == idx) { thread_data[i].x = 0.f; thread_data[i].y = 0.f; break; - } else if (smem_cu_seqlens[mid] < idx) { + } else if (cu_seqlens[mid] < idx) { left = mid + 1; } else { right = mid - 1;
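
Note (not part of the patches above): a minimal PyTorch sketch of the variable-length convention these kernels implement. build_seq_idx mirrors the seq_idx computation used in mamba_simple.py and selective_scan_interface.py, and selective_scan_varlen_ref illustrates the state reset that the binary search over cu_seqlens performs inside the fwd/bwd kernels: at every interior cu_seqlens boundary the discretized decay term exp(delta * A) is zeroed, so no SSM state leaks from one packed sequence into the next. The function names and the simplified scan (real-valued A, no z gate, batch size 1) are illustrative assumptions for this sketch, not code from the patch series.

    import torch

    def build_seq_idx(cu_seqlens):
        # e.g. cu_seqlens = tensor([0, 3, 7, 9]) -> seq_idx = [[0,0,0, 1,1,1,1, 2,2]]
        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        return torch.cat([torch.full((int(s),), i, dtype=torch.int32, device=cu_seqlens.device)
                          for i, s in enumerate(lengths)]).unsqueeze(0)

    def selective_scan_varlen_ref(u, delta, A, B, C, D=None, cu_seqlens=None):
        # u, delta: (dim, L); A: (dim, dstate); B, C: (dstate, L); batch size 1 assumed
        dim, L = u.shape
        # interior boundaries: the first position of every packed sequence after the first
        boundaries = set(cu_seqlens[1:-1].tolist()) if cu_seqlens is not None else set()
        x = torch.zeros(dim, A.shape[1], dtype=torch.float32, device=u.device)
        ys = []
        for t in range(L):
            dA = torch.exp(delta[:, t:t + 1] * A)               # discretized decay, (dim, dstate)
            if t in boundaries:
                dA = torch.zeros_like(dA)                        # reset: drop the previous sequence's state
            dBu = delta[:, t:t + 1] * B[:, t] * u[:, t:t + 1]    # discretized input term, (dim, dstate)
            x = dA * x + dBu
            y = (x * C[:, t]).sum(dim=-1)                        # (dim,)
            if D is not None:
                y = y + D * u[:, t]
            ys.append(y)
        return torch.stack(ys, dim=-1)                           # (dim, L)

Run separately on each sub-sequence and on the packed sequence with cu_seqlens, this reference produces identical outputs, which is the equivalence that tests/ops/test_mamba_cu_seqlens_equivalence.py checks for the CUDA path.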