From 55a7ba5cfc01370b06aa16a35ba7399fedd20703 Mon Sep 17 00:00:00 2001 From: PabloEnriqueGF <51404192+PabloEnriqueGF@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:30:40 +0200 Subject: [PATCH 1/4] Update layer_norm.py --- mamba_ssm/ops/triton/layer_norm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mamba_ssm/ops/triton/layer_norm.py b/mamba_ssm/ops/triton/layer_norm.py index 2f304d43..af8db0d0 100755 --- a/mamba_ssm/ops/triton/layer_norm.py +++ b/mamba_ssm/ops/triton/layer_norm.py @@ -11,7 +11,7 @@ import torch import torch.nn.functional as F -from torch.cuda.amp import custom_fwd, custom_bwd +from torch.amp import custom_fwd, custom_bwd import triton import triton.language as tl @@ -982,7 +982,7 @@ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): class LayerNormLinearFn(torch.autograd.Function): @staticmethod - @custom_fwd + @custom_fwd(device_type="cuda") def forward( ctx, x, @@ -1041,7 +1041,7 @@ def forward( return out if not prenorm else (out, residual_out.reshape(x_shape_og)) @staticmethod - @custom_bwd + @custom_bwd(device_type="cuda") def backward(ctx, dout, *args): x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors dout = dout.reshape(-1, dout.shape[-1]) From 0de6142b16dd82bfbea79af8c6676fadd3582035 Mon Sep 17 00:00:00 2001 From: PabloEnriqueGF <51404192+PabloEnriqueGF@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:31:19 +0200 Subject: [PATCH 2/4] Update tensor_parallel.py --- mamba_ssm/distributed/tensor_parallel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mamba_ssm/distributed/tensor_parallel.py b/mamba_ssm/distributed/tensor_parallel.py index 3660abfc..2bce1ad6 100644 --- a/mamba_ssm/distributed/tensor_parallel.py +++ b/mamba_ssm/distributed/tensor_parallel.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd +from torch.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup from einops import rearrange @@ -22,7 +22,7 @@ class ParallelLinearFunc(torch.autograd.Function): @staticmethod - @custom_fwd + @custom_fwd(device_type="cuda) def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True): """ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel @@ -58,7 +58,7 @@ def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True): return output @staticmethod - @custom_bwd + @custom_bwd(device_type="cuda") def backward(ctx, grad_output): grad_output = grad_output.contiguous() process_group = ctx.process_group From ab8c6b0ba438e92990673a6d362ba2a7c0af8ef8 Mon Sep 17 00:00:00 2001 From: PabloEnriqueGF <51404192+PabloEnriqueGF@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:32:21 +0200 Subject: [PATCH 3/4] Update selective_scan_interface.py --- mamba_ssm/ops/selective_scan_interface.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index c3596bfe..79deb224 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F -from torch.cuda.amp import custom_bwd, custom_fwd +from torch.amp import custom_bwd, custom_fwd from einops import rearrange, repeat @@ -160,7 +160,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta class MambaInnerFn(torch.autograd.Function): @staticmethod - @custom_fwd + @custom_fwd(device_type="cuda") def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, @@ -236,7 +236,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod - @custom_bwd + @custom_bwd(device_type="cuda") def backward(ctx, dout): # dout: (batch, seqlen, dim) assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." From 53bd7d8482ec246d19656e9336c1071015014da4 Mon Sep 17 00:00:00 2001 From: PabloEnriqueGF <51404192+PabloEnriqueGF@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:33:15 +0200 Subject: [PATCH 4/4] Update ssd_combined.py --- mamba_ssm/ops/triton/ssd_combined.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index 77d20715..e6b695b2 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -11,7 +11,7 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd +from torch.amp import custom_bwd, custom_fwd import triton import triton.language as tl @@ -754,7 +754,7 @@ def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D= class MambaSplitConv1dScanCombinedFn(torch.autograd.Function): @staticmethod - @custom_fwd + @custom_fwd(device_type="cuda") def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True): @@ -832,7 +832,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, return out if not return_final_states else (out, final_states) @staticmethod - @custom_bwd + @custom_bwd(device_type="cuda") def backward(ctx, dout, *args): zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors dfinal_states = args[0] if ctx.return_final_states else None