Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue 496 #501

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mamba_ssm/distributed/tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from einops import rearrange
Expand All @@ -22,7 +22,7 @@

class ParallelLinearFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
@custom_fwd(device_type="cuda)
def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
"""
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
Expand Down Expand Up @@ -58,7 +58,7 @@ def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
process_group = ctx.process_group
Expand Down
6 changes: 3 additions & 3 deletions mamba_ssm/ops/selective_scan_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.amp import custom_bwd, custom_fwd

from einops import rearrange, repeat

Expand Down Expand Up @@ -160,7 +160,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta
class MambaInnerFn(torch.autograd.Function):

@staticmethod
@custom_fwd
@custom_fwd(device_type="cuda")
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
Expand Down Expand Up @@ -236,7 +236,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

@staticmethod
@custom_bwd
@custom_bwd(device_type="cuda")
def backward(ctx, dout):
# dout: (batch, seqlen, dim)
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
Expand Down
6 changes: 3 additions & 3 deletions mamba_ssm/ops/triton/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd
from torch.amp import custom_fwd, custom_bwd

import triton
import triton.language as tl
Expand Down Expand Up @@ -982,7 +982,7 @@ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):

class LayerNormLinearFn(torch.autograd.Function):
@staticmethod
@custom_fwd
@custom_fwd(device_type="cuda")
def forward(
ctx,
x,
Expand Down Expand Up @@ -1041,7 +1041,7 @@ def forward(
return out if not prenorm else (out, residual_out.reshape(x_shape_og))

@staticmethod
@custom_bwd
@custom_bwd(device_type="cuda")
def backward(ctx, dout, *args):
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
dout = dout.reshape(-1, dout.shape[-1])
Expand Down
6 changes: 3 additions & 3 deletions mamba_ssm/ops/triton/ssd_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.amp import custom_bwd, custom_fwd

import triton
import triton.language as tl
Expand Down Expand Up @@ -754,7 +754,7 @@ def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=
class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):

@staticmethod
@custom_fwd
@custom_fwd(device_type="cuda")
def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None,
ngroups=1, norm_before_gate=True):
Expand Down Expand Up @@ -832,7 +832,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
return out if not return_final_states else (out, final_states)

@staticmethod
@custom_bwd
@custom_bwd(device_type="cuda")
def backward(ctx, dout, *args):
zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
dfinal_states = args[0] if ctx.return_final_states else None
Expand Down