diff --git a/README.md b/README.md index 6ca031cd..aba5c2ee 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ ![Mamba](assets/selection.png "Selective State Space") > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\ > Albert Gu*, Tri Dao*\ -> Paper: https://arxiv.org/abs/2312.00752 -> **Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality**\ +> Paper: https://arxiv.org/abs/2312.00752\ +> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**\ > Tri Dao*, Albert Gu*\ > Paper: https://arxiv.org/abs/2405.21060 diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py index 64f4c0c4..0204ce2f 100644 --- a/mamba_ssm/__init__.py +++ b/mamba_ssm/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.0" +__version__ = "2.0.1" from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.modules.mamba_simple import Mamba diff --git a/mamba_ssm/distributed/tensor_parallel.py b/mamba_ssm/distributed/tensor_parallel.py index cc55e793..3660abfc 100644 --- a/mamba_ssm/distributed/tensor_parallel.py +++ b/mamba_ssm/distributed/tensor_parallel.py @@ -11,7 +11,7 @@ from einops import rearrange -from src.distributed.distributed_utils import ( +from mamba_ssm.distributed.distributed_utils import ( all_gather_raw, all_reduce, all_reduce_raw, diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py index ad3d5f5a..aefac53d 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_scan.py +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -14,7 +14,7 @@ from einops import rearrange, repeat -from src.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd +from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index 58d806b2..b640ea6b 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -24,20 +24,20 @@ except ImportError: causal_conv1d_fn, causal_conv1d_cuda = None, None -from src.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd -from src.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd -from src.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db -from src.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable -from src.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref -from src.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd -from src.ops.triton.ssd_state_passing import state_passing, state_passing_ref -from src.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates -from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb -from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable -from src.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref -from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev -from src.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd -from src.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd +from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable +from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref +from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd +from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable +from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev +from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd +from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') @@ -651,7 +651,7 @@ def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus Return: out: (batch, seqlen, nheads, headdim) """ - from src.ops.selective_scan_interface import selective_scan_fn + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape