feat: Initial state support for Mamba SSM (1) #488

Open · wants to merge 6 commits into main
15 changes: 10 additions & 5 deletions csrc/selective_scan/selective_scan.cpp
@@ -229,7 +229,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
                    const c10::optional<at::Tensor> &D_,
                    const c10::optional<at::Tensor> &z_,
                    const c10::optional<at::Tensor> &delta_bias_,
-                   bool delta_softplus) {
+                   bool delta_softplus, const c10::optional<at::Tensor> &x) {
     auto input_type = u.scalar_type();
     auto weight_type = A.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -309,15 +309,20 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
     // at::Tensor out = torch::empty_like(u);
     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
     at::Tensor out = torch::empty_like(delta);
-    at::Tensor x;
-    x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
+    if (x.has_value()){
+        auto _x = x.value();
+        TORCH_CHECK(_x.scalar_type() == weight_type);
+        TORCH_CHECK(_x.is_cuda());
+        TORCH_CHECK(_x.stride(-1) == 1);
+        CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
+    }
 
     SSMParamsBase params;
     set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
                        u, delta, A, B, C, out, z, out_z,
                        D_.has_value() ? D_.value().data_ptr() : nullptr,
                        delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
-                       x.data_ptr(),
+                       x.value().data_ptr(),
                        has_z,
                        delta_softplus);
 
@@ -330,7 +335,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
         selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
     });
     });
-    std::vector<at::Tensor> result = {out, x};
+    std::vector<at::Tensor> result = {out, x.value()};
     if (has_z) { result.push_back(out_z); }
     return result;
 }
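With this change the forward op no longer allocates the chunk-state buffer itself: the caller passes it in through the new optional `x` argument, and the checks above define its contract (weight dtype, CUDA, contiguous last dimension, shape `(batch, dim, n_chunks, 2 * dstate)`). A minimal Python-side sketch of a conforming buffer — the helper name is invented for illustration; the even/odd layout and the 2048 chunk length follow the Python-side changes further down:

```python
import torch

def make_scan_state_buffer(u, A, prev_state=None, chunk_len=2048):
    """Sketch: allocate the state buffer the modified selective_scan_cuda.fwd expects.

    Last-dimension layout: even slots carry a chunk's cumulative A-prefix
    (multiplicative identity 1), odd slots carry the running hidden state
    (additive identity 0, or prev_state for the first chunk).
    """
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    n_chunks = (seqlen + chunk_len - 1) // chunk_len
    x = torch.zeros(batch, dim, n_chunks, 2 * dstate, device=u.device, dtype=A.dtype)
    x[:, :, 0, 0::2] = 1
    if prev_state is not None:
        x[:, :, 0, 1::2].copy_(prev_state)  # prev_state: (batch, dim, dstate)
    return x
```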
4 changes: 2 additions & 2 deletions csrc/selective_scan/selective_scan_fwd_kernel.cuh
@@ -241,10 +241,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
             scan_t running_prefix;
             if constexpr (!kIsComplex) {
                 // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
-                running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
+                running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f));
                 // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
             } else {
-                running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
+                running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f));
                 // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
             }
             SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
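The kernel change is confined to the seed of the per-chunk running prefix: for `chunk == 0` it now loads the value stored in `x` (which the Python side pre-fills with `(1, prev_state)`) instead of the scan identity `(1, 0)`. Why that suffices follows from the algebra of the scan: each element is a pair `(a, b)` representing the update `h -> a*h + b`, and two such updates compose into `(a1*a2, a2*b1 + b2)`, so a prefix of `(1, h_prev)` behaves exactly as if an earlier chunk had already ended in state `h_prev`. A scalar sketch of that compose rule (illustration only, not the CUDA code):

```python
def combine(left, right):
    """Compose two first-order updates h -> a*h + b, applied left then right."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

def scan_with_initial_state(elems, h_prev=0.0):
    """Inclusive scan over (a, b) pairs, seeded with the prefix (1, h_prev)."""
    running = (1.0, h_prev)  # identity is (1, 0); the previous state rides in the b slot
    states = []
    for e in elems:
        running = combine(running, e)
        states.append(running[1])  # the b component is the hidden state after this step
    return states

# Seeding with h_prev reproduces h_t = a_t * h_{t-1} + b_t with h_0 = h_prev:
assert scan_with_initial_state([(0.5, 2.0), (0.5, 1.0)], h_prev=4.0) == [4.0, 3.0]
```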
37 changes: 24 additions & 13 deletions mamba_ssm/ops/selective_scan_interface.py
@@ -20,7 +20,7 @@ class SelectiveScanFn(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                return_last_state=False):
+                return_last_state=False, prev_state=None):
         if u.stride(-1) != 1:
             u = u.contiguous()
         if delta.stride(-1) != 1:
@@ -39,26 +39,37 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp
         if C.dim() == 3:
             C = rearrange(C, "b dstate l -> b 1 dstate l")
             ctx.squeeze_C = True
-        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
+        n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
+        x = torch.zeros(
+            (u.shape[0], u.shape[1], n_chunks, int(A.shape[1] * 2),),
+            device=u.device,
+            dtype=torch.float32,
+            requires_grad=u.requires_grad
+        )
+        x[:, :, 0, 0::2] = 1
+        if prev_state is not None:
+            x[:, :, 0, 1::2].copy_(prev_state)
+        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, x)
         ctx.delta_softplus = delta_softplus
         ctx.has_z = z is not None
         last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
         if not ctx.has_z:
-            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
+            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x, prev_state)
             return out if not return_last_state else (out, last_state)
         else:
-            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
+            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out, prev_state)
             out_z = rest[0]
             return out_z if not return_last_state else (out_z, last_state)
 
     @staticmethod
     def backward(ctx, dout, *args):
         if not ctx.has_z:
-            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
+            u, delta, A, B, C, D, delta_bias, x, prev_state = ctx.saved_tensors
             z = None
             out = None
         else:
-            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
+            u, delta, A, B, C, D, z, delta_bias, x, out, prev_state = ctx.saved_tensors
+        assert prev_state is None, "providing prev_state is not supported in training configuration"
         if dout.stride(-1) != 1:
             dout = dout.contiguous()
         # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
@@ -75,21 +86,20 @@ def backward(ctx, dout, *args):
                 dD if D is not None else None,
                 dz,
                 ddelta_bias if delta_bias is not None else None,
-                None,
-                None)
+                None, None, None)
 
 
 def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                      return_last_state=False):
+                      return_last_state=False, prev_state=None):
     """if return_last_state is True, returns (out, last_state)
-    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
+    last_state has shape (batch, dim, dstate). Note that the gradient of the last state and prev_state (if provided) is
     not considered in the backward pass.
     """
-    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
+    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, prev_state)
 
 
 def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                       return_last_state=False):
+                       return_last_state=False, prev_state=None):
     """
     u: r(B D L)
     delta: r(B D L)
@@ -99,6 +109,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta
     D: r(D)
     z: r(B D L)
     delta_bias: r(D), fp32
+    prev_state: r(B D N), fp32
 
     out: r(B D L)
     last_state (optional): r(B D dstate) or c(B D dstate)
@@ -121,7 +132,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta
     else:
         B = B.float()
         C = C.float()
-    x = A.new_zeros((batch, dim, dstate))
+    x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state
     ys = []
     deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
     if not is_variable_B:
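Taken together, the interface change lets a long sequence be scanned in segments while the SSM state is carried across calls. A hedged usage sketch, assuming a CUDA build of this branch; since `backward` asserts that `prev_state` is `None`, the chunked path is shown under `torch.no_grad()`:

```python
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

torch.manual_seed(0)
batch, dim, dstate, seqlen, chunk = 1, 64, 16, 4096, 2048
device = "cuda"

u = torch.randn(batch, dim, seqlen, device=device)
delta = torch.rand(batch, dim, seqlen, device=device)
A = -torch.rand(dim, dstate, device=device)
B = torch.randn(batch, dstate, seqlen, device=device)  # variable (input-dependent) B
C = torch.randn(batch, dstate, seqlen, device=device)  # variable (input-dependent) C
D = torch.randn(dim, device=device)

state = None
outs = []
with torch.no_grad():
    for start in range(0, seqlen, chunk):
        end = start + chunk
        out, state = selective_scan_fn(
            u[..., start:end], delta[..., start:end], A,
            B[..., start:end], C[..., start:end], D,
            delta_softplus=True, return_last_state=True, prev_state=state,
        )
        outs.append(out)
out = torch.cat(outs, dim=-1)  # matches a single full-length call up to numerics
```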
41 changes: 33 additions & 8 deletions tests/ops/test_selective_scan.py
@@ -35,8 +35,9 @@
 @pytest.mark.parametrize("is_variable_C", [True])
 # @pytest.mark.parametrize("is_variable_B", [False, True])
 @pytest.mark.parametrize("is_variable_B", [True])
+@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
 def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,
-                        delta_softplus, return_last_state, seqlen, itype, wtype):
+                        delta_softplus, return_last_state, seqlen, itype, wtype, scan_chunks):
     if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
         pytest.skip()  # This config is not applicable
     device = 'cuda'
@@ -92,13 +93,34 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z
     u_ref = u.detach().clone().requires_grad_()
     delta_ref = delta.detach().clone().requires_grad_()
     delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
-    out, *rest = selective_scan_fn(
-        u, delta, A, B, C, D, z=z,
-        delta_bias=delta_bias, delta_softplus=delta_softplus,
-        return_last_state=return_last_state
-    )
-    if return_last_state:
-        state = rest[0]
+    state = None
+    state_ref = None
+    outs = []
+    for c in range(scan_chunks):
+        chunked_prompt_len = seqlen // scan_chunks
+        chunk_start = chunked_prompt_len * c
+        chunk_end = chunked_prompt_len * (c + 1)
+        if c == scan_chunks - 1:
+            chunk_end = seqlen
+        _B = B
+        if is_variable_B:
+            _B = B[..., chunk_start:chunk_end]
+        _C = C
+        if is_variable_C:
+            _C = C[..., chunk_start:chunk_end]
+        _z = z
+        if has_z:
+            _z = z[..., chunk_start:chunk_end]
+        out, *rest = selective_scan_fn(
+            u[..., chunk_start:chunk_end], delta[..., chunk_start:chunk_end], A, _B, _C, D, z=_z,
+            delta_bias=delta_bias, delta_softplus=delta_softplus,
+            return_last_state=return_last_state, prev_state=state if c > 0 else None
+        )
+        outs.append(out)
+        if return_last_state:
+            state = rest[0]
+    if len(outs) > 1:
+        out = torch.cat(outs, dim=-1)
     out_ref, *rest = selective_scan_ref(
         u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
         delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
@@ -115,6 +137,9 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z
     if return_last_state:
         print(f'State max diff: {(state - state_ref).abs().max().item()}')
         assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
+    if scan_chunks > 1:
+        # gradient test is skipped when scanning in chunks (not supported yet)
+        return
 
     g = torch.randn_like(out)
     out_ref.backward(g)
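The chunked-equals-full property that this test checks on the CUDA path can also be sanity-checked without a GPU, since `selective_scan_ref` now accepts `prev_state` as well. A minimal CPU sketch of that equivalence (shapes follow the docstring; assumes this branch of `mamba_ssm` is importable):

```python
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_ref

torch.manual_seed(0)
batch, dim, dstate, seqlen = 2, 4, 8, 32
u = torch.randn(batch, dim, seqlen)
delta = torch.rand(batch, dim, seqlen)
A = -torch.rand(dim, dstate) - 1.0
B = torch.randn(dim, dstate)
C = torch.randn(dim, dstate)

# One pass over the full sequence.
out_full, state_full = selective_scan_ref(u, delta, A, B, C, return_last_state=True)

# Two passes, threading the state through prev_state.
half = seqlen // 2
out1, state1 = selective_scan_ref(u[..., :half], delta[..., :half], A, B, C,
                                  return_last_state=True)
out2, state2 = selective_scan_ref(u[..., half:], delta[..., half:], A, B, C,
                                  return_last_state=True, prev_state=state1)

assert torch.allclose(torch.cat([out1, out2], dim=-1), out_full, atol=1e-5)
assert torch.allclose(state2, state_full, atol=1e-5)
```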