diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index c3596bfe..503634b6 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -13,7 +13,10 @@
     causal_conv1d_fn = None
     causal_conv1d_cuda = None
 
-import selective_scan_cuda
+try:
+    import selective_scan_cuda
+except ImportError:
+    selective_scan_cuda = None
 
 
 class SelectiveScanFn(torch.autograd.Function):
@@ -85,7 +88,10 @@ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_
     last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered
     in the backward pass.
     """
-    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
+    if selective_scan_cuda is None:
+        return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
+    else:
+        return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
 
 
 def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
@@ -314,11 +320,17 @@ def mamba_inner_fn(
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
     C_proj_bias=None, delta_softplus=True
 ):
-    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-                              out_proj_weight, out_proj_bias,
-                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
+    if causal_conv1d_cuda is None or selective_scan_cuda is None:
+        return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
+                               out_proj_weight, out_proj_bias,
+                               A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
+    else:
+        return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
+                                  out_proj_weight, out_proj_bias,
+                                  A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
 
 
+
 def mamba_inner_ref(
     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
     out_proj_weight, out_proj_bias,
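
A minimal sketch of what the fallback enables, assuming an environment where the selective_scan_cuda extension failed to import (so the guarded import above leaves it as None). The tensor shapes follow the (batch, dim, seqlen) / (dim, dstate) layout used by selective_scan_ref and are illustrative only:

    # Hypothetical CPU-only usage: with selective_scan_cuda unavailable,
    # selective_scan_fn now dispatches to the pure-PyTorch reference scan
    # instead of the module failing to import.
    import torch
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

    batch, dim, dstate, seqlen = 2, 4, 8, 16   # illustrative sizes
    u = torch.randn(batch, dim, seqlen)        # input sequence
    delta = torch.randn(batch, dim, seqlen)    # per-step discretization
    A = torch.randn(dim, dstate)               # state transition parameters
    B = torch.randn(batch, dstate, seqlen)     # input-dependent B
    C = torch.randn(batch, dstate, seqlen)     # input-dependent C

    out, last_state = selective_scan_fn(u, delta, A, B, C,
                                        delta_softplus=True,
                                        return_last_state=True)
    print(out.shape, last_state.shape)  # (2, 4, 16) and (2, 4, 8)

The same pattern covers mamba_inner_fn: when either CUDA extension is missing, the call routes to mamba_inner_ref, trading the fused kernels' speed for portability without changing any function signatures.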