From 23f42a16df513f8d72eb1a54307ed42bfa714bac Mon Sep 17 00:00:00 2001
From: mzusman
Date: Tue, 23 Jul 2024 03:31:40 +0300
Subject: [PATCH 1/6] Working, need to clean up

---
 csrc/selective_scan/selective_scan.cpp     | 11 ++++++--
 .../selective_scan_fwd_kernel.cuh          | 13 +++++++++-
 mamba_ssm/ops/selective_scan_interface.py  | 26 ++++++++++++++++---
 3 files changed, 43 insertions(+), 7 deletions(-)

diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp
index cde867cd..0299a4da 100644
--- a/csrc/selective_scan/selective_scan.cpp
+++ b/csrc/selective_scan/selective_scan.cpp
@@ -4,6 +4,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -229,7 +230,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
                    const c10::optional<at::Tensor> &D_,
                    const c10::optional<at::Tensor> &z_,
                    const c10::optional<at::Tensor> &delta_bias_,
-                   bool delta_softplus) {
+                   bool delta_softplus, const c10::optional<at::Tensor> &prev_state) {
     auto input_type = u.scalar_type();
     auto weight_type = A.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -310,7 +311,13 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
     at::Tensor out = torch::empty_like(delta);
     at::Tensor x;
-    x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
+    if (prev_state.has_value()){
+        x = prev_state.value();
+    }
+    else {
+        x = torch::zeros({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
+    }
+    printf(u.sizes());
 
     SSMParamsBase params;
     set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
index 80e9e37e..7240f27f 100755
--- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh
+++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
@@ -241,7 +241,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
         scan_t running_prefix;
         if constexpr (!kIsComplex) {
             // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
-            running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
+            if (chunk == 0){
+                running_prefix = x[state_idx];
+            }
+            else {
+                if (threadIdx.x % 32 == 0){
+                    running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE];
+                }
+                else {
+                    running_prefix = make_float2(1.f, 0.f);
+                }
+            }
+            // running_prefix = chunk > 0 && threadIdx.x & 32 == 0? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
             // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
         } else {
             running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index c3596bfe..0dbb7d07 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -1,5 +1,6 @@
 # Copyright (c) 2023, Tri Dao, Albert Gu.
 
+import math
 import torch
 import torch.nn.functional as F
 from torch.cuda.amp import custom_bwd, custom_fwd
@@ -20,7 +21,7 @@ class SelectiveScanFn(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                return_last_state=False):
+                return_last_state=False, prev_state=None):
         if u.stride(-1) != 1:
             u = u.contiguous()
         if delta.stride(-1) != 1:
@@ -39,10 +40,27 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp
         if C.dim() == 3:
             C = rearrange(C, "b dstate l -> b 1 dstate l")
             ctx.squeeze_C = True
-        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
+        tmp = None
+        # n_chunks =
+        tmp = torch.zeros((
+            u.shape[0],
+            u.shape[1],
+            1,
+            int(A.shape[1] * 2),
+        ),device=u.device,dtype=torch.float32) ## BS , dim, chunks, dstate
+        if prev_state is not None:
+            print(u.shape,flush=True)
+            tmp[:,:,0,0::2] = 1
+            tmp[:,:,0,1::2].copy_(prev_state)
+        else:
+            # tmp[:,:,1:,1::2] = 1
+            tmp[:,:,:,0::2] = 1
+        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, tmp)
         ctx.delta_softplus = delta_softplus
         ctx.has_z = z is not None
+        # print(x[:, :, -1, :].mean(-1) )
         last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
+        print(x.dtype)
         if not ctx.has_z:
             ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
             return out if not return_last_state else (out, last_state)
@@ -80,12 +98,12 @@ def backward(ctx, dout, *args):
 
 
 def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                      return_last_state=False):
+                      return_last_state=False,prev_state=None):
     """if return_last_state is True, returns (out, last_state)
     last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
     not considered in the backward pass.
""" - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, prev_state) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, From 26a76ffe287d056fb1228a257218a5ac3262ace4 Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 23 Jul 2024 14:30:48 +0300 Subject: [PATCH 2/6] Clean up --- csrc/selective_scan/selective_scan.cpp | 20 +++++------ .../selective_scan_fwd_kernel.cuh | 15 ++------ mamba_ssm/ops/selective_scan_interface.py | 34 +++++++------------ 3 files changed, 24 insertions(+), 45 deletions(-) diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp index 0299a4da..3dabb2b7 100644 --- a/csrc/selective_scan/selective_scan.cpp +++ b/csrc/selective_scan/selective_scan.cpp @@ -4,7 +4,6 @@ #include #include -#include #include #include @@ -230,7 +229,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, const c10::optional &D_, const c10::optional &z_, const c10::optional &delta_bias_, - bool delta_softplus, const c10::optional &prev_state) { + bool delta_softplus, const c10::optional &x) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -310,21 +309,20 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, // at::Tensor out = torch::empty_like(u); // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout at::Tensor out = torch::empty_like(delta); - at::Tensor x; - if (prev_state.has_value()){ - x = prev_state.value(); + if (x.has_value()){ + auto _x = x.value(); + TORCH_CHECK(_x.scalar_type() == weight_type); + TORCH_CHECK(_x.is_cuda()); + TORCH_CHECK(_x.stride(-1) == 1); + CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2); } - else { - x = torch::zeros({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type)); - } - printf(u.sizes()); SSMParamsBase params; set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, u, delta, A, B, C, out, z, out_z, D_.has_value() ? D_.value().data_ptr() : nullptr, delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, - x.data_ptr(), + x.value().data_ptr(), has_z, delta_softplus); @@ -337,7 +335,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, selective_scan_fwd_cuda(params, stream); }); }); - std::vector result = {out, x}; + std::vector result = {out, x.value()}; if (has_z) { result.push_back(out_z); } return result; } diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 7240f27f..9f05340b 100755 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -241,21 +241,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { scan_t running_prefix; if constexpr (!kIsComplex) { // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read - if (chunk == 0){ - running_prefix = x[state_idx]; - } - else { - if (threadIdx.x % 32 == 0){ - running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE]; - } - else { - running_prefix = make_float2(1.f, 0.f); - } - } - // running_prefix = chunk > 0 && threadIdx.x & 32 == 0? 
From 26a76ffe287d056fb1228a257218a5ac3262ace4 Mon Sep 17 00:00:00 2001
From: mzusman
Date: Tue, 23 Jul 2024 14:30:48 +0300
Subject: [PATCH 2/6] Clean up

---
 csrc/selective_scan/selective_scan.cpp     | 20 +++++------
 .../selective_scan_fwd_kernel.cuh          | 15 ++------
 mamba_ssm/ops/selective_scan_interface.py  | 34 +++++++------------
 3 files changed, 24 insertions(+), 45 deletions(-)

diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp
index 0299a4da..3dabb2b7 100644
--- a/csrc/selective_scan/selective_scan.cpp
+++ b/csrc/selective_scan/selective_scan.cpp
@@ -4,7 +4,6 @@
 #include
 #include
-#include
 #include
 #include
@@ -230,7 +229,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
                    const c10::optional<at::Tensor> &D_,
                    const c10::optional<at::Tensor> &z_,
                    const c10::optional<at::Tensor> &delta_bias_,
-                   bool delta_softplus, const c10::optional<at::Tensor> &prev_state) {
+                   bool delta_softplus, const c10::optional<at::Tensor> &x) {
     auto input_type = u.scalar_type();
     auto weight_type = A.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -310,21 +309,20 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
     // at::Tensor out = torch::empty_like(u);
     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
     at::Tensor out = torch::empty_like(delta);
-    at::Tensor x;
-    if (prev_state.has_value()){
-        x = prev_state.value();
+    if (x.has_value()){
+        auto _x = x.value();
+        TORCH_CHECK(_x.scalar_type() == weight_type);
+        TORCH_CHECK(_x.is_cuda());
+        TORCH_CHECK(_x.stride(-1) == 1);
+        CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
     }
-    else {
-        x = torch::zeros({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
-    }
-    printf(u.sizes());
 
     SSMParamsBase params;
     set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
                        u, delta, A, B, C, out, z, out_z,
                        D_.has_value() ? D_.value().data_ptr() : nullptr,
                        delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
-                       x.data_ptr(),
+                       x.value().data_ptr(),
                        has_z,
                        delta_softplus);
@@ -337,7 +335,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
             selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
         });
     });
-    std::vector<at::Tensor> result = {out, x};
+    std::vector<at::Tensor> result = {out, x.value()};
     if (has_z) { result.push_back(out_z); }
     return result;
 }
diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
index 7240f27f..9f05340b 100755
--- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh
+++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
@@ -241,21 +241,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
         scan_t running_prefix;
         if constexpr (!kIsComplex) {
             // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
-            if (chunk == 0){
-                running_prefix = x[state_idx];
-            }
-            else {
-                if (threadIdx.x % 32 == 0){
-                    running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE];
-                }
-                else {
-                    running_prefix = make_float2(1.f, 0.f);
-                }
-            }
-            // running_prefix = chunk > 0 && threadIdx.x & 32 == 0? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
+            running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f));
             // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
         } else {
-            running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
+            running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f));
             // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
         }
         SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 0dbb7d07..063b2b7d 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -1,6 +1,5 @@
 # Copyright (c) 2023, Tri Dao, Albert Gu.
 
-import math
 import torch
 import torch.nn.functional as F
 from torch.cuda.amp import custom_bwd, custom_fwd
@@ -40,27 +39,19 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp
         if C.dim() == 3:
             C = rearrange(C, "b dstate l -> b 1 dstate l")
             ctx.squeeze_C = True
-        tmp = None
-        # n_chunks =
-        tmp = torch.zeros((
-            u.shape[0],
-            u.shape[1],
-            1,
-            int(A.shape[1] * 2),
-        ),device=u.device,dtype=torch.float32) ## BS , dim, chunks, dstate
+        n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
+        x = torch.zeros(
+            (u.shape[0], u.shape[1], n_chunks, int(A.shape[1] * 2),),
+            device=u.device,
+            dtype=torch.float32
+        )
+        x[:, :, 0, 0::2] = 1
         if prev_state is not None:
-            print(u.shape,flush=True)
-            tmp[:,:,0,0::2] = 1
-            tmp[:,:,0,1::2].copy_(prev_state)
-        else:
-            # tmp[:,:,1:,1::2] = 1
-            tmp[:,:,:,0::2] = 1
-        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, tmp)
+            x[:, :, 0, 1::2].copy_(prev_state)
+        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, x)
         ctx.delta_softplus = delta_softplus
         ctx.has_z = z is not None
-        # print(x[:, :, -1, :].mean(-1) )
         last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
-        print(x.dtype)
         if not ctx.has_z:
             ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
             return out if not return_last_state else (out, last_state)
@@ -98,7 +89,7 @@ def backward(ctx, dout, *args):
 
 
 def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                      return_last_state=False,prev_state=None):
+                      return_last_state=False, prev_state=None):
     """if return_last_state is True, returns (out, last_state)
     last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
     not considered in the backward pass.
@@ -107,7 +98,7 @@ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_
 
 
 def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                       return_last_state=False):
+                       return_last_state=False, prev_state=None):
     """
     u: r(B D L)
     delta: r(B D L)
@@ -117,6 +108,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta
     D: r(D)
     z: r(B D L)
     delta_bias: r(D), fp32
+    prev_state: r(B D N), fp32
 
     out: r(B D L)
     last_state (optional): r(B D dstate) or c(B D dstate)
@@ -139,7 +131,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta
     else:
         B = B.float()
         C = C.float()
-    x = A.new_zeros((batch, dim, dstate))
+    x = A.new_zeros((batch, dim, dstate)) if prev_state is not None else prev_state
     ys = []
     deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
     if not is_variable_B:
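
With the cleaned-up interface above, a long prompt can be scanned in pieces by threading each call's last_state into the next call's prev_state. A hedged usage sketch (variable names and the chunking loop are illustrative; it assumes the variable-B/C layout in which u, delta, B, C and z all share the trailing sequence dimension):

    state = None
    outs = []
    for start in range(0, seqlen, chunk_len):
        sl = slice(start, min(start + chunk_len, seqlen))
        out, state = selective_scan_fn(
            u[..., sl], delta[..., sl], A, B[..., sl], C[..., sl], D, z=z[..., sl],
            delta_bias=delta_bias, delta_softplus=True,
            return_last_state=True, prev_state=state,
        )
        outs.append(out)
    out_full = torch.cat(outs, dim=-1)  # should match a single full-length scan

The tests added in the next patch exercise exactly this pattern against the reference implementation.
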
From 028fe7b4b38f807cda59c68d45e7612c318ad3b2 Mon Sep 17 00:00:00 2001
From: mzusman
Date: Tue, 23 Jul 2024 22:07:55 +0300
Subject: [PATCH 3/6] Add chunked mamba tests

---
 mamba_ssm/ops/selective_scan_interface.py |  5 +-
 tests/ops/test_selective_scan.py          | 60 +++++++++++++++++------
 2 files changed, 47 insertions(+), 18 deletions(-)

diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 063b2b7d..4bca5e93 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -84,8 +84,7 @@ def backward(ctx, dout, *args):
                 dD if D is not None else None,
                 dz,
                 ddelta_bias if delta_bias is not None else None,
-                None,
-                None)
+                None, None, None)
 
 
 def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
@@ -131,7 +130,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta
     else:
         B = B.float()
         C = C.float()
-    x = A.new_zeros((batch, dim, dstate)) if prev_state is not None else prev_state
+    x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state
     ys = []
     deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
     if not is_variable_B:
diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py
index 8a834b3c..37ecc656 100644
--- a/tests/ops/test_selective_scan.py
+++ b/tests/ops/test_selective_scan.py
@@ -35,8 +35,9 @@
 @pytest.mark.parametrize("is_variable_C", [True])
 # @pytest.mark.parametrize("is_variable_B", [False, True])
 @pytest.mark.parametrize("is_variable_B", [True])
+@pytest.mark.parametrize("scan_chunks", [1,2,3])
 def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,
-                        delta_softplus, return_last_state, seqlen, itype, wtype):
+                        delta_softplus, return_last_state, seqlen, itype, wtype, scan_chunks):
     if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
         pytest.skip()  # This config is not applicable
     device = 'cuda'
@@ -92,20 +93,46 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z
     u_ref = u.detach().clone().requires_grad_()
     delta_ref = delta.detach().clone().requires_grad_()
     delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
-    out, *rest = selective_scan_fn(
-        u, delta, A, B, C, D, z=z,
-        delta_bias=delta_bias, delta_softplus=delta_softplus,
-        return_last_state=return_last_state
-    )
-    if return_last_state:
-        state = rest[0]
-    out_ref, *rest = selective_scan_ref(
-        u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
-        delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
-        return_last_state=return_last_state
-    )
-    if return_last_state:
-        state_ref = rest[0]
+    state = None
+    state_ref = None
+    for c in range(scan_chunks):
+        chunked_prompt_len = seqlen // scan_chunks
+        chunk_start = chunked_prompt_len * c
+        chunk_end = chunked_prompt_len * (c + 1)
+        if c == scan_chunks - 1:
+            chunk_end = seqlen
+        _B = B
+        if is_variable_B:
+            _B = B[...,chunk_start:chunk_end]
+        _C = C
+        if is_variable_B:
+            _C = C[...,chunk_start:chunk_end]
+        _z = z
+        if has_z:
+            _z = z[...,chunk_start:chunk_end]
+        out, *rest = selective_scan_fn(
+            u[...,chunk_start:chunk_end], delta[...,chunk_start:chunk_end], A, _B, _C, D, z=_z,
+            delta_bias=delta_bias, delta_softplus=delta_softplus,
+            return_last_state=return_last_state,prev_state=state if c > 0 else None
+        )
+        if return_last_state:
+            state = rest[0]
+        _B_ref = B_ref
+        if is_variable_B:
+            _B_ref = B_ref[...,chunk_start:chunk_end]
+        _C_ref = C_ref
+        if is_variable_B:
+            _C_ref = C_ref[...,chunk_start:chunk_end]
+        _z_ref = z_ref
+        if has_z:
+            _z_ref = z_ref[...,chunk_start:chunk_end]
+        out_ref, *rest = selective_scan_ref(
+            u_ref[...,chunk_start:chunk_end], delta_ref[...,chunk_start:chunk_end], A_ref, _B_ref, _C_ref, D_ref, z=_z_ref,
+            delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
+            return_last_state=return_last_state,prev_state=state_ref if c > 0 else None
+        )
+        if return_last_state:
+            state_ref = rest[0]
 
     # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
     # dt_u = delta * u
@@ -115,6 +142,9 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z
     if return_last_state:
         print(f'State max diff: {(state - state_ref).abs().max().item()}')
         assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
+    if scan_chunks > 1:
+        ## skip grad test in case of scan chunks ( not supported atm )
+        return
 
     g = torch.randn_like(out)
     out_ref.backward(g)
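
The chunked test above relies on the selective scan being a linear recurrence in its hidden state: each step computes x_t = deltaA_t * x_{t-1} + deltaB_t * u_t, so everything a later chunk needs from an earlier one is summarized by the earlier chunk's final state. A tiny self-contained illustration of that property for a scalar recurrence (not part of the test suite):

    import torch

    def scan(a, b, h0=0.0):
        h, hs = h0, []
        for at, bt in zip(a.tolist(), b.tolist()):
            h = at * h + bt              # h[t] = a[t] * h[t-1] + b[t]
            hs.append(h)
        return torch.tensor(hs)

    a, b = torch.rand(8), torch.rand(8)
    full = scan(a, b)
    head = scan(a[:5], b[:5])
    tail = scan(a[5:], b[5:], h0=head[-1].item())
    assert torch.allclose(torch.cat([head, tail]), full)
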
From 59f20cb6ddb79b9d19c42479cf1bec6aec0d83ac Mon Sep 17 00:00:00 2001
From: mzusman
Date: Tue, 23 Jul 2024 22:28:36 +0300
Subject: [PATCH 4/6] Test chunked vs not chunked

---
 tests/ops/test_selective_scan.py | 27 +++++++++++----------------
 1 file changed, 11 insertions(+), 16 deletions(-)

diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py
index 37ecc656..356653a3 100644
--- a/tests/ops/test_selective_scan.py
+++ b/tests/ops/test_selective_scan.py
@@ -95,6 +95,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z
     delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
     state = None
     state_ref = None
+    outs = []
     for c in range(scan_chunks):
         chunked_prompt_len = seqlen // scan_chunks
         chunk_start = chunked_prompt_len * c
@@ -115,24 +116,18 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z
             delta_bias=delta_bias, delta_softplus=delta_softplus,
             return_last_state=return_last_state,prev_state=state if c > 0 else None
         )
+        outs.append(out)
         if return_last_state:
             state = rest[0]
-        _B_ref = B_ref
-        if is_variable_B:
-            _B_ref = B_ref[...,chunk_start:chunk_end]
-        _C_ref = C_ref
-        if is_variable_B:
-            _C_ref = C_ref[...,chunk_start:chunk_end]
-        _z_ref = z_ref
-        if has_z:
-            _z_ref = z_ref[...,chunk_start:chunk_end]
-        out_ref, *rest = selective_scan_ref(
-            u_ref[...,chunk_start:chunk_end], delta_ref[...,chunk_start:chunk_end], A_ref, _B_ref, _C_ref, D_ref, z=_z_ref,
-            delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
-            return_last_state=return_last_state,prev_state=state_ref if c > 0 else None
-        )
-        if return_last_state:
-            state_ref = rest[0]
+    if len(outs) > 1:
+        out = torch.cat(outs,dim=-1)
+    out_ref, *rest = selective_scan_ref(
+        u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
+        delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
+        return_last_state=return_last_state
+    )
+    if return_last_state:
+        state_ref = rest[0]
 
     # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
     # dt_u = delta * u

From 643dfbf91f3969c49afc300fb1e064c4d90d8909 Mon Sep 17 00:00:00 2001
From: mzusman
Date: Tue, 23 Jul 2024 22:30:42 +0300
Subject: [PATCH 5/6] Comments

---
 mamba_ssm/ops/selective_scan_interface.py | 2 +-
 tests/ops/test_selective_scan.py          | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 4bca5e93..8b82f57f 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -90,7 +90,7 @@ def backward(ctx, dout, *args):
 def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                       return_last_state=False, prev_state=None):
     """if return_last_state is True, returns (out, last_state)
-    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
+    last_state has shape (batch, dim, dstate). Note that the gradient of the last state and prev_state (if provided) is
     not considered in the backward pass.
     """
     return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, prev_state)
diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py
index 356653a3..10e66a71 100644
--- a/tests/ops/test_selective_scan.py
+++ b/tests/ops/test_selective_scan.py
@@ -138,7 +138,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z
         print(f'State max diff: {(state - state_ref).abs().max().item()}')
         assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
     if scan_chunks > 1:
-        ## skip grad test in case of scan chunks ( not supported atm )
+        # skip grad test in case of scan chunks ( not supported atm )
         return
 
     g = torch.randn_like(out)
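
The backward-pass changes in the final patch below follow the torch.autograd.Function contract: backward must return one value per argument of forward, which is why the non-differentiable extras (delta_softplus, return_last_state, prev_state) each receive a None, and because the existing CUDA backward does not propagate a gradient into an injected initial state, prev_state is rejected outright whenever a backward pass runs. A minimal, generic sketch of that arity rule (illustrative code, not the library's):

    import torch

    class ScaleBy(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, scale, flag):
            ctx.scale = scale
            return x * scale

        @staticmethod
        def backward(ctx, grad_out):
            # one return per forward argument: a gradient for x, None for scale and flag
            return grad_out * ctx.scale, None, None

    x = torch.randn(3, requires_grad=True)
    ScaleBy.apply(x, 2.0, False).sum().backward()  # fills x.grad; scale and flag get None
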
From 84a2ec8b16d5180f9edeeb222f9e6177f4b26363 Mon Sep 17 00:00:00 2001
From: mzusman
Date: Wed, 24 Jul 2024 11:00:33 +0300
Subject: [PATCH 6/6] Add assert in backward pass

---
 mamba_ssm/ops/selective_scan_interface.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 8b82f57f..a93af07f 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -43,7 +43,8 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp
         x = torch.zeros(
             (u.shape[0], u.shape[1], n_chunks, int(A.shape[1] * 2),),
             device=u.device,
-            dtype=torch.float32
+            dtype=torch.float32,
+            requires_grad=u.requires_grad
         )
         x[:, :, 0, 0::2] = 1
         if prev_state is not None:
@@ -53,21 +54,22 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp
         ctx.has_z = z is not None
         last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
         if not ctx.has_z:
-            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
+            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x, prev_state)
             return out if not return_last_state else (out, last_state)
         else:
-            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
+            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out, prev_state)
             out_z = rest[0]
             return out_z if not return_last_state else (out_z, last_state)
 
     @staticmethod
     def backward(ctx, dout, *args):
         if not ctx.has_z:
-            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
+            u, delta, A, B, C, D, delta_bias, x, prev_state = ctx.saved_tensors
             z = None
             out = None
         else:
-            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
+            u, delta, A, B, C, D, z, delta_bias, x, out, prev_state = ctx.saved_tensors
+        assert prev_state is None, "providing prev_state is not supported in training configuration"
         if dout.stride(-1) != 1:
             dout = dout.contiguous()
         # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the