Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

selective_scan_cuda error #265

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions mamba_ssm/ops/selective_scan_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
causal_conv1d_fn = None
causal_conv1d_cuda = None

import selective_scan_cuda
try:
import selective_scan_cuda
except ImportError:
selective_scan_cuda = None


class SelectiveScanFn(torch.autograd.Function):
Expand Down Expand Up @@ -85,7 +88,10 @@ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
not considered in the backward pass.
"""
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
if selective_scan_cuda is None:
return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
else:
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
Expand Down Expand Up @@ -314,11 +320,17 @@ def mamba_inner_fn(
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
C_proj_bias=None, delta_softplus=True
):
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
if causal_conv1d_cuda is None:
return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
else:
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)



def mamba_inner_ref(
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
Expand Down