Support variable-length sequences for mamba block with position indices #434
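This PR threads an optional per-token position index through the forward and backward selective-scan kernels so that several variable-length sequences can be packed along the time dimension of one batch row. Below is a hypothetical illustration (plain C++, not part of the diff) of the index layout the kernel changes appear to assume: positions restart at 0 at the start of each packed sequence, and the kernels treat `index == 0` as the boundary across which the recurrent state must not carry over.

```cpp
// Hypothetical helper, not from the PR: build a flat position index for
// sequences packed back to back. The value restarts at 0 at every sequence
// start, which is the condition the kernels key off (index_vals_load[i] == 0).
#include <vector>

std::vector<int> make_position_index(const std::vector<int>& seq_lens) {
    std::vector<int> index;
    for (int len : seq_lens) {
        for (int pos = 0; pos < len; ++pos) {
            index.push_back(pos);
        }
    }
    return index;  // e.g. lengths {3, 2} -> {0, 1, 2, 0, 1}
}
```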

Open · wants to merge 4 commits into base: main
38 changes: 30 additions & 8 deletions csrc/selective_scan/selective_scan.cpp
@@ -78,8 +78,9 @@ void set_ssm_params_fwd(SSMParamsBase &params,
void* D_ptr,
void* delta_bias_ptr,
void* x_ptr,
bool has_z,
bool delta_softplus) {
bool has_z,
bool delta_softplus,
void* index_ptr) {

// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -109,6 +110,9 @@ params.x_ptr = x_ptr;
params.x_ptr = x_ptr;
params.z_ptr = has_z ? z.data_ptr() : nullptr;
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;

params.index_ptr = index_ptr;

// All stride are in elements, not bytes.
params.A_d_stride = A.stride(0);
params.A_dstate_stride = A.stride(1);
@@ -173,15 +177,16 @@ void set_ssm_params_bwd(SSMParamsBwd &params,
void* ddelta_bias_ptr,
bool has_z,
bool delta_softplus,
bool recompute_out_z) {
bool recompute_out_z,
void* index_ptr) {
// Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
u, delta, A, B, C, has_z ? out : dout,
has_z ? z : dout,
// If not recompute_out_z, pass dout instead of out_z.
// This won't be used by the bwd kernel
recompute_out_z ? out_z : dout,
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus, index_ptr);
if (!recompute_out_z) { params.out_z_ptr = nullptr; }

// Set the pointers and strides.
@@ -229,7 +234,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
const c10::optional<at::Tensor> &D_,
const c10::optional<at::Tensor> &z_,
const c10::optional<at::Tensor> &delta_bias_,
bool delta_softplus) {
bool delta_softplus,
const c10::optional<at::Tensor> &index_) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -292,6 +298,12 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
CHECK_SHAPE(delta_bias, dim);
}
if (index_.has_value()) {
auto index = index_.value();
TORCH_CHECK(index.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(index.is_cuda());
CHECK_SHAPE(index, batch_size, seqlen);
}

at::Tensor z, out_z;
const bool has_z = z_.has_value();
@@ -319,7 +331,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
x.data_ptr(),
has_z,
delta_softplus);
delta_softplus,
index_.has_value() ? index_.value().data_ptr() : nullptr);

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
@@ -346,7 +359,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
const c10::optional<at::Tensor> &out_,
c10::optional<at::Tensor> &dz_,
bool delta_softplus,
bool recompute_out_z) {
bool recompute_out_z,
const c10::optional<at::Tensor> &index_) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -414,8 +428,15 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
CHECK_SHAPE(delta_bias, dim);
}

if (index_.has_value()) {
auto index = index_.value();
TORCH_CHECK(index.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(index.is_cuda());
CHECK_SHAPE(index, batch_size, seqlen);
}
at::Tensor z, out, dz, out_z;
const bool has_z = z_.has_value();

if (has_z) {
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
@@ -474,7 +495,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
dout, du, ddelta, dA, dB, dC, dz,
D_.has_value() ? dD.data_ptr() : nullptr,
delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
has_z, delta_softplus, recompute_out_z);
has_z, delta_softplus, recompute_out_z,
index_.has_value() ? index_.value().data_ptr() : nullptr);

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
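As validated above, the new index_ argument must be an int32 CUDA tensor of shape (batch_size, seqlen). A minimal libtorch sketch of a tensor that satisfies those checks, for the trivial case of one unbroken full-length sequence per batch row (illustrative only, not part of the PR):

```cpp
// Hypothetical host-side helper: produces an index tensor that passes the
// TORCH_CHECKs in selective_scan_fwd/bwd (kInt32, on CUDA, shape [batch, seqlen]).
#include <torch/torch.h>

torch::Tensor make_full_length_index(int64_t batch, int64_t seqlen) {
    // One unbroken sequence per row: positions 0 .. seqlen-1, repeated for each batch entry.
    return torch::arange(seqlen, torch::dtype(torch::kInt32).device(torch::kCUDA))
        .unsqueeze(0)
        .expand({batch, seqlen})
        .contiguous();
}
```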
1 change: 1 addition & 0 deletions csrc/selective_scan/selective_scan.h
@@ -66,6 +66,7 @@ struct SSMParamsBase {
void *__restrict__ x_ptr;
void *__restrict__ z_ptr;
void *__restrict__ out_z_ptr;
void *__restrict__ index_ptr;
};

struct SSMParamsBwd: public SSMParamsBase {
65 changes: 49 additions & 16 deletions csrc/selective_scan/selective_scan_bwd_kernel.cuh
@@ -24,7 +24,7 @@ template<> __device__ __forceinline__ float conj<float>(float x) { return x; }
template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }

template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
bool kDeltaSoftplus_, bool kHasZ_, bool kUseIndex_, typename input_t_, typename weight_t_>
struct Selective_Scan_bwd_kernel_traits {
static_assert(kNItems_ % 4 == 0);
using input_t = input_t_;
@@ -42,13 +42,17 @@ struct Selective_Scan_bwd_kernel_traits {
static constexpr bool kIsVariableC = kIsVariableC_;
static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
static constexpr bool kHasZ = kHasZ_;
static constexpr bool kUseIndex = kUseIndex_;
static constexpr int kNLoadsIndex = kNItems / 4;
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
// For complex this would lead to massive register spilling, so we keep it at 2.
static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2;
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadIndexT = cub::BlockLoad<int, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadIndexVecT = cub::BlockLoad<uint4, kNThreads, kNLoadsIndex, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
@@ -80,6 +84,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
constexpr bool kHasZ = Ktraits::kHasZ;
constexpr bool kUseIndex = Ktraits::kUseIndex;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNItems = Ktraits::kNItems;
using input_t = typename Ktraits::input_t;
@@ -94,6 +99,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
auto& smem_load_index = reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
@@ -136,21 +142,30 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
float dD_val = 0;
float ddelta_bias_val = 0;

int *index = !kUseIndex ? nullptr : reinterpret_cast<int *>(params.index_ptr) + batch_id * params.seqlen;
constexpr int kChunkSize = kNThreads * kNItems;
u += (params.n_chunks - 1) * kChunkSize;
if constexpr (kUseIndex) { index += (params.n_chunks - 1) * kChunkSize; }  // index is nullptr when kUseIndex is false
delta += (params.n_chunks - 1) * kChunkSize;
dout += (params.n_chunks - 1) * kChunkSize;
Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
input_t u_vals[kNItems];

input_t delta_vals_load[kNItems];
input_t dout_vals_load[kNItems];
int index_vals_load[kNItems];
__syncthreads();
load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
u -= kChunkSize;
__syncthreads();
if constexpr (kUseIndex) {
load_index<Ktraits>(index, index_vals_load, smem_load_index, params.seqlen - chunk * kChunkSize);
index -= kChunkSize;
}
__syncthreads();

load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
// Will reload delta at the same location if kDeltaSoftplus
if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
@@ -244,8 +259,16 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
if constexpr (!kIsComplex) {
#pragma unroll
for (int i = 0; i < kNItems; ++i) {
const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
float delta_a_exp = exp2f(delta_vals[i] * A_scaled);

// Reset A bar for cumulative sequences (Real)
if constexpr (kUseIndex) {
if (index_vals_load[i] == 0) {
delta_a_exp = 0.f;
}
}
thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);

if (i == 0) {
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
} else {
@@ -332,6 +355,14 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
for (int i = 0; i < kNItems; ++i) {
// Pytorch's implementation of complex exp (which calls thrust) is very slow
complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);

// Reset A bar for cumulative sequences (Complex)
if constexpr (kUseIndex) {
if (index_vals_load[i] == 0) {
delta_a_exp.real_ = 0.f;
delta_a_exp.imag_ = 0.f;
}
}
weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
if (i == 0) {
@@ -495,19 +526,21 @@ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
// using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
// TODO: check this
constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
// printf("smem_size = %d\n", kSmemSize);
dim3 grid(params.batch, params.dim);
auto kernel = &selective_scan_bwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] {
using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, kUseIndex, input_t, weight_t>;
// using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
// TODO: check this
constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
// printf("smem_size = %d\n", kSmemSize);
dim3 grid(params.batch, params.dim);
auto kernel = &selective_scan_bwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
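The kUseIndex branches above zero out delta_a_exp, the discretized A coefficient, wherever the position index is 0, i.e. at the start of each packed sequence (the same reset is applied in both the real and complex paths). A scalar reference sketch of why this severs the recurrence at sequence boundaries, written against the per-channel recurrence h[t] = a[t] * h[t-1] + b[t] with a[t] = exp(delta[t] * A) and b[t] = delta[t] * B[t] * u[t], rather than the kernel's block-parallel scan:

```cpp
// Reference-only sequential sketch, not the CUDA scan: forcing a[t] = 0 at
// index == 0 gives h[t] = b[t], so no state leaks across the boundary between
// two packed sequences.
#include <cstddef>
#include <vector>

std::vector<float> scan_with_reset(const std::vector<float>& a,
                                   const std::vector<float>& b,
                                   const std::vector<int>& index) {
    std::vector<float> h(a.size());
    float state = 0.f;
    for (std::size_t t = 0; t < a.size(); ++t) {
        const float a_t = (index[t] == 0) ? 0.f : a[t];  // the reset added by this PR
        state = a_t * state + b[t];
        h[t] = state;
    }
    return h;
}
```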
16 changes: 16 additions & 0 deletions csrc/selective_scan/selective_scan_common.h
@@ -162,6 +162,22 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
}
}

template<typename Ktraits>
inline __device__ void load_index(int *u,
int (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
reinterpret_cast<uint4*>(u),
reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
);
} else {
Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
}
}

template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
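load_index mirrors load_input: when the chunk is full (kIsEvenLen), it reinterprets four consecutive int32 indices as one uint4 and issues kNLoadsIndex = kNItems / 4 vectorized loads per thread; otherwise it falls back to a bounds-guarded element-wise load that pads out-of-range slots with 0. A scalar sketch of the guarded path's semantics, ignoring CUB's warp-transposed item-to-thread layout (illustrative only):

```cpp
// Hypothetical scalar reference for the guarded path: slots past the remaining
// valid length are filled with the default value 0, matching the trailing
// arguments of BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0).
void load_index_reference(const int* chunk, int* vals, int n_items, int valid) {
    for (int i = 0; i < n_items; ++i) {
        vals[i] = (i < valid) ? chunk[i] : 0;
    }
}
```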