[Feature] Support variable-length sequences for mamba block #244

Open

wants to merge 37 commits into main from feat/add-cu_seqlens

Commits (37) · Changes from all commits
d28e1b0
add cu_seqlens support and ensure numerical equality
zigzagcai Mar 8, 2024
a78a9eb
add notes for variable length sequences
zigzagcai Mar 14, 2024
e223353
fix typos
zigzagcai Mar 15, 2024
5955450
fix typos
zigzagcai Mar 18, 2024
ca189f6
Merge branch 'main' into feat/add-cu_seqlens
zigzagcai Mar 18, 2024
c2d5b88
fix typos
Dmovic Mar 18, 2024
db0dd09
fix typos
zigzagcai Mar 18, 2024
842bef5
Merge branch 'main' into feat/add-cu_seqlens
zigzagcai Mar 18, 2024
e7774aa
refine cu_seqlens implementation
zigzagcai Mar 18, 2024
1ccc60f
Merge branch 'feat/add-cu_seqlens' into feat/add-cu_seqlens
Dmovic Mar 19, 2024
4bf2697
Merge pull request #1 from Dmovic/feat/add-cu_seqlens
zigzagcai Mar 19, 2024
f357c44
add unit test for variable length
Dmovic Mar 19, 2024
6b98161
update unit test
Dmovic Mar 19, 2024
e4af927
fix typos
zigzagcai Mar 19, 2024
4221d48
update selective scan
zigzagcai Mar 25, 2024
934c0e6
Add logic for variable-length sequences
wang-zerui Mar 25, 2024
63b646d
Merge branch 'main' into feat/add-cu_seqlens
zigzagcai Apr 18, 2024
f6bb7e2
add example test to prove the mathematical equivalence of cu_seqlens …
zigzagcai Apr 26, 2024
bffcd97
fix typos
zigzagcai Apr 26, 2024
e3cab98
add cu_seqlens support for MixerModel
zigzagcai Apr 26, 2024
2f01ede
code refine for tests
zigzagcai Apr 30, 2024
f0a6508
refine code for tests
zigzagcai Apr 30, 2024
623d246
update API notes
zigzagcai Apr 30, 2024
ef3f760
update test code
zigzagcai Apr 30, 2024
71c77b1
Merge remote-tracking branch 'origin/main' into feat/add-cu_seqlens
zigzagcai Jun 6, 2024
2d27ccc
fix conflicts with latest main branch
zigzagcai Jun 6, 2024
f802627
Merge remote-tracking branch 'origin/main' into feat/add-cu_seqlens
zigzagcai Jul 16, 2024
596943c
fix unittest for test_selective_state_update_with_heads
zigzagcai Jul 16, 2024
6961faa
Merge branch 'state-spaces:main' into feat/add-cu_seqlens
zigzagcai Jul 18, 2024
b69b957
migrate to tridao's native varlen causal_conv1d kernel for speedup
zigzagcai Jul 19, 2024
50bffae
Merge branch 'state-spaces:main' into feat/add-cu_seqlens
zigzagcai Jul 22, 2024
909f970
typo fix
zigzagcai Jul 23, 2024
8174c45
use seq_idx if provided, or compute it by cu_seqlens
zigzagcai Aug 5, 2024
59be631
use seq_idx if provided, or compute it by cu_seqlens
zigzagcai Aug 5, 2024
3bc4a51
Merge branch 'state-spaces:main' into feat/add-cu_seqlens
zigzagcai Aug 6, 2024
210b6f6
mv cu_seqlens in ssm kernel to smem
zigzagcai Aug 7, 2024
cda4b5a
remove smem implementation because const vals and bi-search is enough
zigzagcai Aug 7, 2024
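
Before the diffs, a quick orientation on the convention this PR adopts: multiple variable-length sequences are packed into a single batch row, and `cu_seqlens` holds the cumulative sequence lengths, starting at 0 and ending at the total packed length L (the same convention flash-attn's varlen kernels use). Below is a minimal sketch of that layout and of the `seq_idx` tensor the forward pass derives from it when not supplied; the tensor values are illustrative only:

```python
import torch

# Three sequences of lengths 3, 5, 2 packed into one row: total length L = 10.
seqs = [torch.randn(n, 16) for n in (3, 5, 2)]              # each (L_i, d_model)
packed = torch.cat(seqs, dim=0).unsqueeze(0)                # (1, 10, 16): batch dim B = 1
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.long)  # starts at 0, ends at L, sorted

# seq_idx labels every packed position with its sequence id, mirroring what
# Mamba.forward() computes from cu_seqlens in the mamba_simple.py diff below.
lens = cu_seqlens[1:] - cu_seqlens[:-1]
seq_idx = torch.repeat_interleave(torch.arange(len(lens), dtype=torch.int32), lens).unsqueeze(0)
print(seq_idx)  # tensor([[0, 0, 0, 1, 1, 1, 1, 1, 2, 2]], dtype=torch.int32)
```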
28 changes: 21 additions & 7 deletions csrc/selective_scan/selective_scan.cpp
@@ -79,7 +79,9 @@ void set_ssm_params_fwd(SSMParamsBase &params,
void* delta_bias_ptr,
void* x_ptr,
bool has_z,
-                        bool delta_softplus) {
+                        bool delta_softplus,
+                        void* cu_seqlens_ptr,
+                        const int cu_seqlens_size) {

// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -109,6 +111,10 @@
params.x_ptr = x_ptr;
params.z_ptr = has_z ? z.data_ptr() : nullptr;
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;

params.cu_seqlens_ptr = cu_seqlens_ptr;
params.cu_seqlens_size = cu_seqlens_size;

// All stride are in elements, not bytes.
params.A_d_stride = A.stride(0);
params.A_dstate_stride = A.stride(1);
@@ -173,15 +179,17 @@ void set_ssm_params_bwd(SSMParamsBwd &params,
void* ddelta_bias_ptr,
bool has_z,
bool delta_softplus,
-                        bool recompute_out_z) {
+                        bool recompute_out_z,
+                        void* cu_seqlens_ptr,
+                        const int cu_seqlens_size) {
// Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
u, delta, A, B, C, has_z ? out : dout,
has_z ? z : dout,
// If not recompute_out_z, pass dout instead of out_z.
// This won't be used by the bwd kernel
recompute_out_z ? out_z : dout,
-                       D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
+                       D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus, cu_seqlens_ptr, cu_seqlens_size);
if (!recompute_out_z) { params.out_z_ptr = nullptr; }

// Set the pointers and strides.
@@ -229,7 +237,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
const c10::optional<at::Tensor> &D_,
const c10::optional<at::Tensor> &z_,
const c10::optional<at::Tensor> &delta_bias_,
-                   bool delta_softplus) {
+                   bool delta_softplus,
+                   const c10::optional<at::Tensor> &cu_seqlens_) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -319,7 +328,9 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
x.data_ptr(),
has_z,
-                      delta_softplus);
+                      delta_softplus,
+                      cu_seqlens_.has_value() ? cu_seqlens_.value().data_ptr() : nullptr,
+                      cu_seqlens_.has_value() ? cu_seqlens_.value().size(0) : 0);

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
@@ -346,7 +357,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
const c10::optional<at::Tensor> &out_,
c10::optional<at::Tensor> &dz_,
bool delta_softplus,
-                   bool recompute_out_z) {
+                   bool recompute_out_z,
+                   const c10::optional<at::Tensor> &cu_seqlens_) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -474,7 +486,9 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
dout, du, ddelta, dA, dB, dC, dz,
D_.has_value() ? dD.data_ptr() : nullptr,
delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
-                      has_z, delta_softplus, recompute_out_z);
+                      has_z, delta_softplus, recompute_out_z,
+                      cu_seqlens_.has_value() ? cu_seqlens_.value().data_ptr() : nullptr,
+                      cu_seqlens_.has_value() ? cu_seqlens_.value().size(0) : 0);

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
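
Two details of this entry-point change are easy to miss from the Python side. First, `cu_seqlens` is threaded through as a raw pointer plus a size, and the kernels below reinterpret that pointer as `long *`, so the tensor must be int64. Second, the docstring added in mamba_simple.py requires the values to start at 0, end at L, and already be sorted. A small sketch of a helper that checks those invariants; the function name is mine, not part of the PR:

```python
import torch

def check_cu_seqlens(cu_seqlens: torch.Tensor, total_len: int) -> None:
    """Validate the invariants the new selective_scan entry points assume."""
    assert cu_seqlens.dtype == torch.long, "kernels read the buffer as long*"
    assert cu_seqlens.ndim == 1 and cu_seqlens.numel() >= 2
    assert cu_seqlens[0].item() == 0, "must start at 0"
    assert cu_seqlens[-1].item() == total_len, "must end at the packed length L"
    assert bool((cu_seqlens[1:] >= cu_seqlens[:-1]).all()), "must already be sorted"

check_cu_seqlens(torch.tensor([0, 3, 8, 10]), total_len=10)  # passes
```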
4 changes: 4 additions & 0 deletions csrc/selective_scan/selective_scan.h
@@ -33,6 +33,8 @@ struct SSMParamsBase {

bool delta_softplus;

int cu_seqlens_size;

index_t A_d_stride;
index_t A_dstate_stride;
index_t B_batch_stride;
@@ -66,6 +68,8 @@
void *__restrict__ x_ptr;
void *__restrict__ z_ptr;
void *__restrict__ out_z_ptr;

void *__restrict__ cu_seqlens_ptr;
};

struct SSMParamsBwd: public SSMParamsBase {
41 changes: 40 additions & 1 deletion csrc/selective_scan/selective_scan_bwd_kernel.cuh
@@ -142,6 +142,9 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
float dD_val = 0;
float ddelta_bias_val = 0;

const int cu_seqlens_size = params.cu_seqlens_size;
const long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);

constexpr int kChunkSize = kNThreads * kNItems;
u += (params.n_chunks - 1) * kChunkSize;
@@ -250,8 +253,26 @@
if constexpr (!kIsComplex) {
#pragma unroll
for (int i = 0; i < kNItems; ++i) {
-                const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
+                float delta_a_exp = exp2f(delta_vals[i] * A_scaled);

// Reset A bar for cumulative sequences (Real)
int left = 1;
int right = cu_seqlens_size - 2;
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
while (left <= right) {
int mid = (left + right) >> 1;
if (cu_seqlens[mid] == idx) {
delta_a_exp = 0.f;
break;
} else if (cu_seqlens[mid] < idx) {
left = mid + 1;
} else {
right = mid - 1;
}
}

thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);

if (i == 0) {
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
} else {
@@ -338,6 +359,24 @@
for (int i = 0; i < kNItems; ++i) {
// Pytorch's implementation of complex exp (which calls thrust) is very slow
complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);

// Reset A bar for cumulative sequences (Complex)
int left = 1;
int right = cu_seqlens_size - 2;
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
while (left <= right) {
int mid = (left + right) >> 1;
if (cu_seqlens[mid] == idx) {
delta_a_exp.real_ = 0.f;
delta_a_exp.imag_ = 0.f;
break;
} else if (cu_seqlens[mid] < idx) {
left = mid + 1;
} else {
right = mid - 1;
}
}

weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
if (i == 0) {
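
The backward kernel applies the boundary trick in both the real and complex branches: for each element it binary-searches `cu_seqlens` (interior entries only, hence `left = 1` and `right = cu_seqlens_size - 2`; the sentinels 0 and L never need a reset), and if the element's global index lands exactly on a sequence start it zeroes the decay term `delta_a_exp`, so no state or gradient flows across packed sequences. A reference sketch of that idea in PyTorch, for a plain recurrence h_t = a_t * h_{t-1} + b_t rather than the kernel's parallel scan; names and shapes are illustrative:

```python
import torch

def scan_with_resets(a: torch.Tensor, b: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Zero the decay a_t at every interior sequence start (the kernel's
    delta_a_exp = 0.f branch) so state cannot leak across boundaries."""
    boundaries = set(cu_seqlens[1:-1].tolist())  # what the kernel locates via binary search
    h, out = torch.zeros(()), []
    for t in range(a.numel()):
        a_t = torch.zeros(()) if t in boundaries else a[t]
        h = a_t * h + b[t]
        out.append(h)
    return torch.stack(out)

a, b = torch.rand(10), torch.randn(10)
cu = torch.tensor([0, 3, 8, 10])
packed = scan_with_resets(a, b, cu)

# Equivalent to scanning each sequence independently from a zero state:
pieces = [scan_with_resets(a[s:e], b[s:e], torch.tensor([0, e - s]))
          for s, e in zip(cu[:-1].tolist(), cu[1:].tolist())]
assert torch.allclose(packed, torch.cat(pieces))
```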
39 changes: 39 additions & 0 deletions csrc/selective_scan/selective_scan_fwd_kernel.cuh
@@ -112,6 +112,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;

const int cu_seqlens_size = params.cu_seqlens_size;
const long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);

float D_val[kNRows] = {0};
if (params.D_ptr != nullptr) {
@@ -220,6 +223,23 @@
if constexpr (!kIsComplex) {
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);

// Reset A bar for cumulative sequences (Real)
int left = 1;
int right = cu_seqlens_size - 2;
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
while (left <= right) {
int mid = (left + right) >> 1;
if (cu_seqlens[mid] == idx) {
thread_data[i].x = 0.f;
break;
} else if (cu_seqlens[mid] < idx) {
left = mid + 1;
} else {
right = mid - 1;
}
}

if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
thread_data[i] = make_float2(1.f, 0.f);
@@ -230,6 +250,25 @@
complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);

// Reset A bar for cumulative sequences (Complex)
int left = 1;
int right = cu_seqlens_size - 2;
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
while (left <= right) {
int mid = (left + right) >> 1;
if (cu_seqlens[mid] == idx) {
thread_data[i].x = 0.f;
thread_data[i].y = 0.f;
break;
} else if (cu_seqlens[mid] < idx) {
left = mid + 1;
} else {
right = mid - 1;
}
}


if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
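
With both kernels resetting state at boundaries, the packed scan should match per-sequence scans up to float tolerance. Commit f6bb7e2 adds a test in this spirit at the op level; the sketch below outlines such a check through the public `selective_scan_fn` wrapper, whose new `cu_seqlens` keyword appears in the mamba_simple.py diff further down. Shapes and the tolerance here are illustrative assumptions:

```python
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

torch.manual_seed(0)
dim, dstate = 4, 8
cu_seqlens = torch.tensor([0, 3, 8, 10], device="cuda")   # int64 by default
L = int(cu_seqlens[-1])

u = torch.randn(1, dim, L, device="cuda")
delta = torch.rand(1, dim, L, device="cuda")
A = -torch.rand(dim, dstate, device="cuda")
B = torch.randn(1, 1, dstate, L, device="cuda")
C = torch.randn(1, 1, dstate, L, device="cuda")
D = torch.randn(dim, device="cuda")

# One packed call, boundaries supplied via cu_seqlens...
y_packed = selective_scan_fn(u, delta, A, B, C, D, delta_softplus=True, cu_seqlens=cu_seqlens)

# ...should match per-sequence calls concatenated along the length dim.
y_ref = torch.cat([selective_scan_fn(u[..., s:e], delta[..., s:e], A,
                                     B[..., s:e], C[..., s:e], D, delta_softplus=True)
                   for s, e in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist())], dim=-1)

assert torch.allclose(y_packed, y_ref, atol=1e-5)
```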
29 changes: 26 additions & 3 deletions mamba_ssm/modules/mamba_simple.py
@@ -116,13 +116,23 @@ def __init__(

self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

-    def forward(self, hidden_states, inference_params=None):
+    def forward(self, hidden_states, cu_seqlens=None, seq_idx=None, inference_params=None):
"""
hidden_states: (B, L, D)
cu_seqlens: (Optional) cumulative sequence lengths, starting at 0 and ending with L; must already be sorted.
Returns: same shape as hidden_states
"""
batch, seqlen, dim = hidden_states.shape

if cu_seqlens is not None:
assert batch == 1 and cu_seqlens.ndimension() == 1, "varlen mamba1 is only supported with B=1"
# compute seq_idx if not provided
if seq_idx is None:
seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0)
else:
seq_idx = None

conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
@@ -157,6 +167,8 @@ def forward(self, hidden_states, inference_params=None):
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
cu_seqlens=cu_seqlens,
seq_idx=seq_idx,
)
else:
x, z = xz.chunk(2, dim=1)
@@ -166,13 +178,23 @@
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
if causal_conv1d_fn is None:
-                x = self.act(self.conv1d(x)[..., :seqlen])
+                if cu_seqlens is not None:
+                    # naive pure python implementation of varlen causal_conv1d
+                    for i, s in enumerate(cu_seqlens[1:-1]):
+                        x = torch.cat((x[..., :s + i*(self.d_conv - 1)], torch.zeros_like(x[..., :(self.d_conv - 1)]), x[..., s + i*(self.d_conv - 1):]), dim=2)
+                    mask = torch.cat([torch.cat((torch.full((s,), True, dtype=torch.bool, device=x.device),
+                                                 torch.full((self.d_conv - 1,), False, dtype=torch.bool, device=x.device)), dim=0)
+                                      for s in (cu_seqlens[1:] - cu_seqlens[:-1])], dim=0)
+                    x = self.act(self.conv1d(x)[:, :, mask])
+                else:
+                    x = self.act(self.conv1d(x)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
-                x=x,
+                x=x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
seq_idx=seq_idx,
activation=self.activation,
)

@@ -197,6 +219,7 @@
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
cu_seqlens=cu_seqlens,
)
if ssm_state is not None:
y, last_state = y
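
Finally, at the module level the new arguments compose as follows. A minimal usage sketch under the PR's stated constraint that varlen mamba1 only supports B=1; the `d_model` value, sequence lengths, and tolerance are arbitrary example choices:

```python
import torch
from mamba_ssm.modules.mamba_simple import Mamba

device = "cuda"
block = Mamba(d_model=64).to(device)

# Pack three sequences of lengths 3, 5, 2 into one (B=1, L=10, D=64) row.
xs = [torch.randn(n, 64, device=device) for n in (3, 5, 2)]
hidden_states = torch.cat(xs, dim=0).unsqueeze(0)
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.long, device=device)

y_packed = block(hidden_states, cu_seqlens=cu_seqlens)   # (1, 10, 64)

# Should agree with running each sequence through the block separately:
y_ref = torch.cat([block(x.unsqueeze(0)) for x in xs], dim=1)
assert torch.allclose(y_packed, y_ref, atol=1e-4)
```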