diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 00000000..990e1ff9 --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,212 @@ +# This workflow will: +# - Create a new Github release +# - Build wheels for supported architectures +# - Deploy the wheels to the Github release +# - Release the static code to PyPi +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries + +name: Build wheels and deploy + +on: + create: + tags: + - v* + +jobs: + + setup_release: + name: Create Release + runs-on: ubuntu-latest + steps: + - name: Get the tag version + id: extract_branch + run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} + shell: bash + + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.extract_branch.outputs.branch }} + release_name: ${{ steps.extract_branch.outputs.branch }} + + build_wheels: + name: Build Wheel + needs: setup_release + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the + # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. + os: [ubuntu-20.04] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.1', '2.2.0.dev20231127'] + cuda-version: ['11.8.0', '12.2.0'] + # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. + # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. + # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) + # when building without C++11 ABI and using it on nvcr images. + cxx11_abi: ['FALSE', 'TRUE'] + exclude: + # Pytorch <= 1.12 does not support Python 3.11 + - torch-version: '1.12.1' + python-version: '3.11' + # Pytorch >= 2.0 only supports Python >= 3.8 + - torch-version: '2.0.1' + python-version: '3.7' + - torch-version: '2.1.1' + python-version: '3.7' + - torch-version: '2.2.0.dev20231127' + python-version: '3.7' + # Pytorch <= 2.0 only supports CUDA <= 11.8 + - torch-version: '1.12.1' + cuda-version: '12.2.0' + - torch-version: '1.13.1' + cuda-version: '12.2.0' + - torch-version: '2.0.1' + cuda-version: '12.2.0' + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Set CUDA and PyTorch versions + run: | + echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV + + - name: Free up disk space + if: ${{ runner.os == 'Linux' }} + # https://github.com/easimon/maximize-build-space/blob/master/action.yml + # https://github.com/easimon/maximize-build-space/tree/test-report + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + + - name: Set up swap space + if: runner.os == 'Linux' + uses: pierotofy/set-swap-space@v1.0 + with: + swap-size-gb: 10 + + - name: Install CUDA ${{ matrix.cuda-version }} + if: ${{ matrix.cuda-version != 'cpu' }} + uses: Jimver/cuda-toolkit@v0.2.11 + id: cuda-toolkit + with: + cuda: ${{ matrix.cuda-version }} + linux-local-args: '["--toolkit"]' + # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 + # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} + method: 'network' + # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions, + # not just nvcc + # sub-packages: '["nvcc"]' + + - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} + run: | + pip install --upgrade pip + # If we don't install before installing Pytorch, we get error for torch 2.0.1 + # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none) + pip install lit + # We want to figure out the CUDA version to download pytorch + # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # This code is ugly, maybe there's a better way to do this. + export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))") + if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then + pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + else + pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} + fi + nvcc --version + python --version + python -c "import torch; print('PyTorch:', torch.__version__)" + python -c "import torch; print('CUDA:', torch.version.cuda)" + python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" + shell: + bash + + - name: Build wheel + run: | + # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 + # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 + # However this still fails so I'm using a newer version of setuptools + pip install setuptools==68.0.0 + pip install ninja packaging wheel + export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH + export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH + # Limit MAX_JOBS otherwise the github runner goes OOM + MAX_JOBS=2 MAMBA_FORCE_BUILD="TRUE" MAMBA_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist + tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} + wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") + ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} + echo "wheel_name=${wheel_name}" >> $GITHUB_ENV + + - name: Log Built Wheels + run: | + ls dist + + - name: Get the tag version + id: extract_branch + run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} + + - name: Get Release with tag + id: get_current_release + uses: joutvhu/get-release@v1 + with: + tag_name: ${{ steps.extract_branch.outputs.branch }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload Release Asset + id: upload_release_asset + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.get_current_release.outputs.upload_url }} + asset_path: ./dist/${{env.wheel_name}} + asset_name: ${{env.wheel_name}} + asset_content_type: application/* + + publish_package: + name: Publish package + needs: [build_wheels] + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + pip install ninja packaging setuptools wheel twine + # We don't want to download anything CUDA-related here + pip install torch --index-url https://download.pytorch.org/whl/cpu + + - name: Build core package + env: + MAMBA_SKIP_CUDA_BUILD: "TRUE" + run: | + python setup.py sdist --dist-dir=dist + + - name: Deploy + env: + TWINE_USERNAME: "__token__" + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + run: | + python -m twine upload dist/* diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..a7445800 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "3rdparty/lm-evaluation-harness"] + path = 3rdparty/lm-evaluation-harness + url = https://github.com/EleutherAI/lm-evaluation-harness/ diff --git a/3rdparty/lm-evaluation-harness b/3rdparty/lm-evaluation-harness new file mode 160000 index 00000000..a3520619 --- /dev/null +++ b/3rdparty/lm-evaluation-harness @@ -0,0 +1 @@ +Subproject commit a35206191acac1776761e737b66e0d04975d21b9 diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 00000000..38557a87 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,2 @@ +Tri Dao, tri@tridao.me +Albert Gu, agu@andrew.cmu.edu diff --git a/README.md b/README.md index bdc06251..7a957e8f 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,141 @@ # Mamba -This repository contains the code for the paper [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752). +![Mamba](assets/selection.png "Selective State Space") +> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\ +> Albert Gu*, Tri Dao*\ +> Paper: https://arxiv.org/abs/2312.00752 -The first official code release of the paper will be uploaded around noon EST, Monday Dec. 4. +## Installation + +- `pip install causal-conv1d`: an efficient implemention of a simple causal Conv1d layer used inside the Mamba block. +- `pip install mamba-ssm`: the core Mamba package. + +If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`. + +Other requirements: +- Linux +- NVIDIA GPU +- PyTorch 1.12+ +- CUDA 11.6+ + +## Usage + +We expose several levels of interface with the Mamba model. + +### Selective SSM + +Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2). + +Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py). + +### Mamba Block + +The main module of this repository is the Mamba architecture block wrapping the selective SSM. + +Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py). + +Usage: +``` +from mamba_ssm import Mamba + +batch, length, dim = 2, 64, 16 +x = torch.randn(batch, length, dim).to("cuda") +model = Mamba( + # This module uses roughly 3 * expand * d_model^2 parameters + d_model=dim, # Model dimension d_model + d_state=16, # SSM state expansion factor + d_conv=4, # Local convolution width + expand=2, # Block expansion factor +).to("cuda") +y = model(x) +assert y.shape == x.shape +``` + +### Mamba Language Model + +Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head. + +Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py). + +This is an example of how to integrate Mamba into an end-to-end neural network. +This example is used in the generation scripts below. + + + +## Pretrained Models + +Pretrained models are uploaded to +[HuggingFace](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`, +`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`. + +The models will be autodownloaded by the generation script below. + +These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models: + +| Parameters | Layers | Model dim. | +|------------|--------|------------| +| 130M | 12 | 768 | +| 370M | 24 | 1024 | +| 790M | 24 | 1536 | +| 1.4B | 24 | 2048 | +| 2.8B | 32 | 2560 | + +(The layer count of Mamba should be doubled, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.) + +Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.). +Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models. + + +## Evaluations + +To run zero-shot evaluations of models (corresponding to Table 3 of the paper), +we use the +[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) +library. + +1. Pull the `lm-evaluation-harness` repo by `git submodule update --init + --recursive`. We use the `big-refactor` branch. +2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness` +3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo): +``` +python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 +python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 +``` + +Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process. + +## Inference + +The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py) +1. autoloads a model from the HuggingFace Hub, +2. generates completions of a user-specified prompt, +3. benchmarks the inference speed of this generation. + +Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature. + +### Examples + +To test generation latency (e.g. batch size = 1) with different sampling strategies: + +``` +python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5 +python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5 +``` + +To test generation throughput with random prompts (e.g. large batch size): +``` +python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128 +python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128 +``` + +## Citation + +If you use this codebase, or otherwise found our work valuable, please cite Mamba: +``` +@article{mamba, + title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, + author={Gu, Albert and Dao, Tri}, + journal={arXiv preprint arXiv:2312.00752}, + year={2023} +} +``` diff --git a/assets/selection.png b/assets/selection.png new file mode 100644 index 00000000..69b109a8 Binary files /dev/null and b/assets/selection.png differ diff --git a/benchmarks/benchmark_generation_mamba_simple.py b/benchmarks/benchmark_generation_mamba_simple.py new file mode 100644 index 00000000..8ff1b77f --- /dev/null +++ b/benchmarks/benchmark_generation_mamba_simple.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023, Tri Dao, Albert Gu. + +import argparse +import time +import json + +import torch +import torch.nn.functional as F + +from einops import rearrange + +from transformers import AutoTokenizer, AutoModelForCausalLM + +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + + +parser = argparse.ArgumentParser(description="Generation benchmarking") +parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") +parser.add_argument("--prompt", type=str, default=None) +parser.add_argument("--promptlen", type=int, default=100) +parser.add_argument("--genlen", type=int, default=100) +parser.add_argument("--temperature", type=float, default=1.0) +parser.add_argument("--topk", type=int, default=1) +parser.add_argument("--topp", type=float, default=1.0) +parser.add_argument("--batch", type=int, default=1) +args = parser.parse_args() + +repeats = 3 +device = "cuda" +dtype = torch.float16 + +print(f"Loading model {args.model_name}") +is_mamba = args.model_name.startswith("state-spaces/mamba-") +if is_mamba: + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) +else: + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) +model.eval() +print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + +torch.random.manual_seed(0) +if args.prompt is None: + input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") + attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") +else: + tokens = tokenizer(args.prompt, return_tensors="pt") + input_ids = tokens.input_ids.to(device=device) + attn_mask = tokens.attention_mask.to(device=device) +max_length = input_ids.shape[1] + args.genlen + +if is_mamba: + fn = lambda: model.generate( + input_ids=input_ids, + max_length=max_length, + cg=True, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + temperature=args.temperature, + top_k=args.topk, + top_p=args.topp, + ) +else: + fn = lambda: model.generate( + input_ids=input_ids, + attention_mask=attn_mask, + max_length=max_length, + return_dict_in_generate=True, + pad_token_id=tokenizer.eos_token_id, + do_sample=True, + temperature=args.temperature, + top_k=args.topk, + top_p=args.topp, + ) +out = fn() +if args.prompt is not None: + print(tokenizer.batch_decode(out.sequences.tolist())) + +torch.cuda.synchronize() +start = time.time() +for _ in range(repeats): + fn() +torch.cuda.synchronize() +print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") +print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") diff --git a/csrc/selective_scan/reverse_scan.cuh b/csrc/selective_scan/reverse_scan.cuh new file mode 100644 index 00000000..d7e93174 --- /dev/null +++ b/csrc/selective_scan/reverse_scan.cuh @@ -0,0 +1,401 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include +// #include +#include "uninitialized_copy.cuh" + +/** + * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) { + static_assert(LENGTH > 0); + T retval = input[LENGTH - 1]; + #pragma unroll + for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); } + return retval; +} + +/** + * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadReverseScanInclusive( + const T (&input)[LENGTH], + T (&output)[LENGTH], + ScanOp scan_op, + const T postfix) +{ + T inclusive = postfix; + #pragma unroll + for (int i = LENGTH - 1; i >= 0; --i) { + inclusive = scan_op(inclusive, input[i]); + output[i] = inclusive; + } +} + +/** + * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadReverseScanExclusive( + const T (&input)[LENGTH], + T (&output)[LENGTH], + ScanOp scan_op, + const T postfix) +{ + // Careful, output maybe be aliased to input + T exclusive = postfix; + T inclusive; + #pragma unroll + for (int i = LENGTH - 1; i >= 0; --i) { + inclusive = scan_op(exclusive, input[i]); + output[i] = exclusive; + exclusive = inclusive; + } + return inclusive; +} + + +/** + * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp. + * + * LOGICAL_WARP_THREADS must be a power-of-two + */ +template < + typename T, ///< Data type being scanned + int LOGICAL_WARP_THREADS ///< Number of threads per logical warp + > +struct WarpReverseScan { + //--------------------------------------------------------------------- + // Constants and type definitions + //--------------------------------------------------------------------- + + /// Whether the logical warp size and the PTX warp size coincide + static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0)); + /// The number of warp scan steps + static constexpr int STEPS = cub::Log2::VALUE; + static_assert(LOGICAL_WARP_THREADS == 1 << STEPS); + + + //--------------------------------------------------------------------- + // Thread fields + //--------------------------------------------------------------------- + + /// Lane index in logical warp + unsigned int lane_id; + + /// Logical warp index in 32-thread physical warp + unsigned int warp_id; + + /// 32-thread physical warp member mask of logical warp + unsigned int member_mask; + + //--------------------------------------------------------------------- + // Construction + //--------------------------------------------------------------------- + + /// Constructor + explicit __device__ __forceinline__ + WarpReverseScan() + : lane_id(cub::LaneId()) + , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS)) + , member_mask(cub::WarpMask(warp_id)) + { + if (!IS_ARCH_WARP) { + lane_id = lane_id % LOGICAL_WARP_THREADS; + } + } + + + /// Broadcast + __device__ __forceinline__ T Broadcast( + T input, ///< [in] The value to broadcast + int src_lane) ///< [in] Which warp lane is to do the broadcasting + { + return cub::ShuffleIndex(input, src_lane, member_mask); + } + + + /// Inclusive scan + template + __device__ __forceinline__ void InclusiveReverseScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOpT scan_op) ///< [in] Binary scan operator + { + inclusive_output = input; + #pragma unroll + for (int STEP = 0; STEP < STEPS; STEP++) { + int offset = 1 << STEP; + T temp = cub::ShuffleDown( + inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask + ); + // Perform scan op if from a valid peer + inclusive_output = static_cast(lane_id) >= LOGICAL_WARP_THREADS - offset + ? inclusive_output : scan_op(temp, inclusive_output); + } + } + + /// Exclusive scan + // Get exclusive from inclusive + template + __device__ __forceinline__ void ExclusiveReverseScan( + T input, ///< [in] Calling thread's input item. + T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOpT scan_op, ///< [in] Binary scan operator + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. + { + T inclusive_output; + InclusiveReverseScan(input, inclusive_output, scan_op); + warp_aggregate = cub::ShuffleIndex(inclusive_output, 0, member_mask); + // initial value unknown + exclusive_output = cub::ShuffleDown( + inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask + ); + } + + /** + * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last warp-lane is undefined. + */ + template + __device__ __forceinline__ void ReverseScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. + T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item. + ScanOpT scan_op) ///< [in] Binary scan operator + { + InclusiveReverseScan(input, inclusive_output, scan_op); + // initial value unknown + exclusive_output = cub::ShuffleDown( + inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask + ); + } + +}; + +/** + * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block. + */ +template < + typename T, ///< Data type being scanned + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure + > +struct BlockReverseScan { + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// Constants + /// The thread block size in threads + static constexpr int BLOCK_THREADS = BLOCK_DIM_X; + + /// Layout type for padded thread block raking grid + using BlockRakingLayout = cub::BlockRakingLayout; + // The number of reduction elements is not a multiple of the number of raking threads for now + static_assert(BlockRakingLayout::UNGUARDED); + + /// Number of raking threads + static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS; + /// Number of raking elements per warp synchronous raking thread + static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH; + /// Cooperative work can be entirely warp synchronous + static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS)); + + /// WarpReverseScan utility type + using WarpReverseScan = WarpReverseScan; + + /// Shared memory storage layout type + struct _TempStorage { + typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : cub::Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + T cached_segment[SEGMENT_LENGTH]; + + + //--------------------------------------------------------------------- + // Utility methods + //--------------------------------------------------------------------- + + /// Performs upsweep raking reduction, returning the aggregate + template + __device__ __forceinline__ T Upsweep(ScanOp scan_op) { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + // Read data into registers + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } + T raking_partial = cached_segment[SEGMENT_LENGTH - 1]; + #pragma unroll + for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) { + raking_partial = scan_op(raking_partial, cached_segment[i]); + } + return raking_partial; + } + + + /// Performs exclusive downsweep raking scan + template + __device__ __forceinline__ void ExclusiveDownsweep( + ScanOp scan_op, + T raking_partial) + { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + // Read data back into registers + if (!MEMOIZE) { + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } + } + ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial); + // Write data back to smem + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; } + } + + + //--------------------------------------------------------------------- + // Constructors + //--------------------------------------------------------------------- + + /// Constructor + __device__ __forceinline__ BlockReverseScan( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1)) + {} + + + /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template < + typename ScanOp, + typename BlockPostfixCallbackOp> + __device__ __forceinline__ void ExclusiveReverseScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide postfix to be applied to all inputs. + { + if (WARP_SYNCHRONOUS) { + // Short-circuit directly to warp-synchronous scan + T block_aggregate; + WarpReverseScan warp_scan; + warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate); + // Obtain warp-wide postfix in lane0, then broadcast to other lanes + T block_postfix = block_postfix_callback_op(block_aggregate); + block_postfix = warp_scan.Broadcast(block_postfix, 0); + exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output); + } else { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + detail::uninitialized_copy(placement_ptr, input); + cub::CTA_SYNC(); + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) { + WarpReverseScan warp_scan; + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + // Warp-synchronous scan + T exclusive_partial, block_aggregate; + warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); + // Obtain block-wide postfix in lane0, then broadcast to other lanes + T block_postfix = block_postfix_callback_op(block_aggregate); + block_postfix = warp_scan.Broadcast(block_postfix, 0); + // Update postfix with warpscan exclusive partial + T downsweep_postfix = linear_tid == RAKING_THREADS - 1 + ? block_postfix : scan_op(block_postfix, exclusive_partial); + // Exclusive raking downsweep scan + ExclusiveDownsweep(scan_op, downsweep_postfix); + } + cub::CTA_SYNC(); + // Grab thread postfix from shared memory + exclusive_output = *placement_ptr; + + // // Compute warp scan in each warp. + // // The exclusive output from the last lane in each warp is invalid. + // T inclusive_output; + // WarpReverseScan warp_scan; + // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op); + + // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid. + // T block_aggregate; + // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate); + + // // Apply warp postfix to our lane's partial + // if (warp_id != 0) { + // exclusive_output = scan_op(warp_postfix, exclusive_output); + // if (lane_id == 0) { exclusive_output = warp_postfix; } + // } + + // // Use the first warp to determine the thread block postfix, returning the result in lane0 + // if (warp_id == 0) { + // T block_postfix = block_postfix_callback_op(block_aggregate); + // if (lane_id == 0) { + // // Share the postfix with all threads + // detail::uninitialized_copy(&temp_storage.block_postfix, + // block_postfix); + + // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0 + // } + // } + + // cub::CTA_SYNC(); + + // // Incorporate thread block postfix into outputs + // T block_postfix = temp_storage.block_postfix; + // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); } + } + } + + + /** + * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + */ + template < + int ITEMS_PER_THREAD, + typename ScanOp, + typename BlockPostfixCallbackOp> + __device__ __forceinline__ void InclusiveReverseScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence. + { + // Reduce consecutive thread items in registers + T thread_postfix = ThreadReverseReduce(input, scan_op); + // Exclusive thread block-scan + ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op); + // Inclusive scan in registers with postfix as seed + ThreadReverseScanInclusive(input, output, scan_op, thread_postfix); + } + +}; \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp new file mode 100644 index 00000000..f51af402 --- /dev/null +++ b/csrc/selective_scan/selective_scan.cpp @@ -0,0 +1,497 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include "selective_scan.h" + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Half) { \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::BFloat16) { \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::ComplexFloat) { \ + using weight_t = c10::complex; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + const at::Tensor z, + const at::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool has_z, + bool delta_softplus) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +void set_ssm_params_bwd(SSMParamsBwd ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor z, + const at::Tensor out, + const at::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + const at::Tensor dout, + const at::Tensor du, + const at::Tensor ddelta, + const at::Tensor dA, + const at::Tensor dB, + const at::Tensor dC, + const at::Tensor dz, + void* dD_ptr, + void* ddelta_bias_ptr, + bool has_z, + bool delta_softplus, + bool recompute_out_z) { + // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z + set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, has_z ? out : dout, + has_z ? z : dout, + // If not recompute_out_z, pass dout instead of out_z. + // This won't be used by the bwd kernel + recompute_out_z ? out_z : dout, + D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); + if (!recompute_out_z) { params.out_z_ptr = nullptr; } + + // Set the pointers and strides. + params.dout_ptr = dout.data_ptr(); + params.du_ptr = du.data_ptr(); + params.dA_ptr = dA.data_ptr(); + params.dB_ptr = dB.data_ptr(); + params.dC_ptr = dC.data_ptr(); + params.dD_ptr = dD_ptr; + params.ddelta_ptr = ddelta.data_ptr(); + params.ddelta_bias_ptr = ddelta_bias_ptr; + params.dz_ptr = has_z ? dz.data_ptr() : nullptr; + // All stride are in elements, not bytes. + params.dout_batch_stride = dout.stride(0); + params.dout_d_stride = dout.stride(1); + params.dA_d_stride = dA.stride(0); + params.dA_dstate_stride = dA.stride(1); + if (!is_variable_B) { + params.dB_d_stride = dB.stride(0); + } else { + params.dB_batch_stride = dB.stride(0); + params.dB_group_stride = dB.stride(1); + } + params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2); + if (!is_variable_C) { + params.dC_d_stride = dC.stride(0); + } else { + params.dC_batch_stride = dC.stride(0); + params.dC_group_stride = dC.stride(1); + } + params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2); + params.du_batch_stride = du.stride(0); + params.du_d_stride = du.stride(1); + params.ddelta_batch_stride = ddelta.stride(0); + params.ddelta_d_stride = ddelta.stride(1); + if (has_z) { + params.dz_batch_stride = dz.stride(0); + params.dz_d_stride = dz.stride(1); + } +} + +std::vector +selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + bool delta_softplus) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + const bool is_complex = weight_type == at::ScalarType::ComplexFloat; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + if (!is_variable_B) { + CHECK_SHAPE(B, dim, dstate); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); + TORCH_CHECK(B.stride(-1) == 1); + } + if (!is_variable_C) { + CHECK_SHAPE(C, dim, dstate); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); + TORCH_CHECK(C.stride(-1) == 1); + } + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + out_z = torch::empty_like(z); + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + // at::Tensor out = torch::empty_like(u); + // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout + at::Tensor out = torch::empty_like(delta); + at::Tensor x; + x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type)); + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, out, z, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x.data_ptr(), + has_z, + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda(params, stream); + }); + }); + std::vector result = {out, x}; + if (has_z) { result.push_back(out_z); } + return result; +} + +std::vector +selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + const at::Tensor &dout, + const c10::optional &x_, + const c10::optional &out_, + c10::optional &dz_, + bool delta_softplus, + bool recompute_out_z) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + const bool is_complex = weight_type == at::ScalarType::ComplexFloat; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + TORCH_CHECK(dout.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + TORCH_CHECK(dout.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1); + TORCH_CHECK(dout.stride(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + if (!is_variable_B) { + CHECK_SHAPE(B, dim, dstate); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); + TORCH_CHECK(B.stride(-1) == 1); + } + if (!is_variable_C) { + CHECK_SHAPE(C, dim, dstate); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); + TORCH_CHECK(C.stride(-1) == 1); + } + CHECK_SHAPE(dout, batch_size, dim, seqlen); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor z, out, dz, out_z; + const bool has_z = z_.has_value(); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + + TORCH_CHECK(out_.has_value()); + out = out_.value(); + TORCH_CHECK(out.scalar_type() == input_type); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.stride(-1) == 1); + CHECK_SHAPE(out, batch_size, dim, seqlen); + + if (dz_.has_value()) { + dz = dz_.value(); + TORCH_CHECK(dz.scalar_type() == input_type); + TORCH_CHECK(dz.is_cuda()); + TORCH_CHECK(dz.stride(-1) == 1); + CHECK_SHAPE(dz, batch_size, dim, seqlen); + } else { + dz = torch::empty_like(z); + } + if (recompute_out_z) { + out_z = torch::empty_like(out); + } + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); } + if (x_.has_value()) { + auto x = x_.value(); + TORCH_CHECK(x.scalar_type() == weight_type); + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.is_contiguous()); + CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate); + } + + at::Tensor du = torch::empty_like(u); + at::Tensor ddelta = torch::empty_like(delta); + at::Tensor dA = torch::zeros_like(A); + at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32)); + at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32)); + at::Tensor dD; + if (D_.has_value()) { dD = torch::zeros_like(D_.value()); } + at::Tensor ddelta_bias; + if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); } + + SSMParamsBwd params; + set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, z, out, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x_.has_value() ? x_.value().data_ptr() : nullptr, + dout, du, ddelta, dA, dB, dC, dz, + D_.has_value() ? dD.data_ptr() : nullptr, + delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, + has_z, delta_softplus, recompute_out_z); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { + DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] { + selective_scan_bwd_cuda(params, stream); + }); + }); + std::vector result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; + if (has_z) { result.push_back(dz); } + if (recompute_out_z) { result.push_back(out_z); } + return result; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fwd", &selective_scan_fwd, "Selective scan forward"); + m.def("bwd", &selective_scan_bwd, "Selective scan backward"); +} diff --git a/csrc/selective_scan/selective_scan.h b/csrc/selective_scan/selective_scan.h new file mode 100644 index 00000000..e2c7bcdb --- /dev/null +++ b/csrc/selective_scan/selective_scan.h @@ -0,0 +1,101 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMScanParamsBase { + using index_t = uint32_t; + + int batch, seqlen, n_chunks; + index_t a_batch_stride; + index_t b_batch_stride; + index_t out_batch_stride; + + // Common data pointers. + void *__restrict__ a_ptr; + void *__restrict__ b_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. + void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; + void *__restrict__ z_ptr; + void *__restrict__ out_z_ptr; +}; + +struct SSMParamsBwd: public SSMParamsBase { + index_t dout_batch_stride; + index_t dout_d_stride; + index_t dA_d_stride; + index_t dA_dstate_stride; + index_t dB_batch_stride; + index_t dB_group_stride; + index_t dB_d_stride; + index_t dB_dstate_stride; + index_t dC_batch_stride; + index_t dC_group_stride; + index_t dC_d_stride; + index_t dC_dstate_stride; + index_t du_batch_stride; + index_t du_d_stride; + index_t dz_batch_stride; + index_t dz_d_stride; + index_t ddelta_batch_stride; + index_t ddelta_d_stride; + + // Common data pointers. + void *__restrict__ dout_ptr; + void *__restrict__ dA_ptr; + void *__restrict__ dB_ptr; + void *__restrict__ dC_ptr; + void *__restrict__ dD_ptr; + void *__restrict__ du_ptr; + void *__restrict__ dz_ptr; + void *__restrict__ ddelta_ptr; + void *__restrict__ ddelta_bias_ptr; +}; diff --git a/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu b/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu new file mode 100644 index 00000000..c55f0e85 --- /dev/null +++ b/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_bf16_real.cu b/csrc/selective_scan/selective_scan_bwd_bf16_real.cu new file mode 100644 index 00000000..72adaf5c --- /dev/null +++ b/csrc/selective_scan/selective_scan_bwd_bf16_real.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu b/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu new file mode 100644 index 00000000..df126d7c --- /dev/null +++ b/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_fp16_real.cu b/csrc/selective_scan/selective_scan_bwd_fp16_real.cu new file mode 100644 index 00000000..3ff271b5 --- /dev/null +++ b/csrc/selective_scan/selective_scan_bwd_fp16_real.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu b/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu new file mode 100644 index 00000000..55549023 --- /dev/null +++ b/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_fp32_real.cu b/csrc/selective_scan/selective_scan_bwd_fp32_real.cu new file mode 100644 index 00000000..a7ed6422 --- /dev/null +++ b/csrc/selective_scan/selective_scan_bwd_fp32_real.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh new file mode 100644 index 00000000..2ed10114 --- /dev/null +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -0,0 +1,531 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template __device__ __forceinline__ scalar_t conj(scalar_t x); +template<> __device__ __forceinline__ float conj(float x) { return x; } +template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + static constexpr bool kHasZ = kHasZ_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2; + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockReduceComplexT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * MAX_DSTATE + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + MAX_DSTATE); + weight_t *smem_dbc = reinterpret_cast(smem_da + MAX_DSTATE); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id; + float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast(params.D_ptr)[dim_id]; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id; + float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; + scan_t *x = params.x_ptr == nullptr + ? nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; + float dD_val = 0; + float ddelta_bias_val = 0; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNItems]; + input_t delta_vals_load[kNItems]; + input_t dout_vals_load[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + u -= kChunkSize; + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + __syncthreads(); + load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + dout -= kChunkSize; + + float dout_vals[kNItems], delta_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[i] = float(dout_vals_load[i]); + delta_vals[i] = float(delta_vals_load[i]) + delta_bias; + if constexpr (kDeltaSoftplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * params.z_d_stride + chunk * kChunkSize; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * params.out_d_stride + chunk * kChunkSize; + input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride + + dim_id * params.dz_d_stride + chunk * kChunkSize; + input_t z_vals[kNItems], out_vals[kNItems]; + __syncthreads(); + load_input(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize); + float dz_vals[kNItems], z_silu_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); + z_silu_vals[i] = z_val * z_sigmoid_val; + dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val + * (1.0f + z_val * (1.0f - z_sigmoid_val)); + dout_vals[i] *= z_silu_vals[i]; + } + __syncthreads(); + store_output(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize); + if (params.out_z_ptr != nullptr) { // Recompute and store out_z + float out_z_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; } + // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { + // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]); + // } + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * params.out_z_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize); + } + } + + float du_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); } + + float ddelta_vals[kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + const weight_t A_val = A[state_idx * params.A_dstate_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + weight_t A_scaled; + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_scaled = A_val * kLog2e; + } else { + A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_); + } + weight_t B_val, C_val; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (!kIsVariableB) { + B_val = B[state_idx * params.B_dstate_stride]; + } else { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + if constexpr (!kIsVariableC) { + C_val = C[state_idx * params.C_dstate_stride]; + } else { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + // const weight_t A_val = smem_a[state_idx]; + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + if constexpr (!kIsComplex) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[i] * + (!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) + : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; + du_vals[i] += ddelta_u * delta_vals[i]; + const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; + dA_val += dx * delta_vals[i] * a; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += dout_vals[i] * thread_data[i].y; + } + } + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + if constexpr (kIsVariableB) { + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + } + if constexpr (kIsVariableC) { + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + } + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float2 dA_dBC_val = make_float2(dA_val, dBC_val); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = dA_dBC_val.x; + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; + } + } else { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); + weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp.real_; + thread_reverse_data[i - 1].y = -delta_a_exp.imag_; + } + complex_t dout_BC = 2 * dout_vals[i] + * conj(!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); + thread_reverse_data[i].z = dout_BC.real_; + thread_reverse_data[i].w = dout_BC.imag_; + } + __syncthreads(); + complex_t delta_a_exp = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) + : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; + thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; + thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + complex_t x = complex_t(thread_data[i].z, thread_data[i].w); + complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); + float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += (2 * dout_vals[i]) * conj(x); + } + } + const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i])); + du_vals[i] += ddelta_u * delta_vals[i]; + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_; + dA_val += delta_vals[i] * dx * a_conj; + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; + if constexpr (kIsVariableB) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dB_vals_f[i * 2] = dB_vals[i].real_; + dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; + } + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); + } + if constexpr (kIsVariableC) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dC_vals_f[i * 2] = dC_vals[i].real_; + dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; + } + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); + } + const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; + float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems * 2; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); + dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; + } + } + } + + if constexpr (kDeltaSoftplus) { + __syncthreads(); + input_t delta_vals_load[kNItems]; + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + delta -= kChunkSize; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[i]) + delta_bias; + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[i] = delta_val <= 20.f + ? ddelta_vals[i] / (1.f + delta_val_neg_exp) + : ddelta_vals[i]; + } + } + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * params.ddelta_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); + + Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); + Cvar -= kChunkSize * (!kIsComplex ? 1 : 2); + } + if (params.dD_ptr != nullptr) { + dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val); + if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); } + } + if (params.ddelta_bias_ptr != nullptr) { + __syncthreads(); + ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val); + if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); } + } + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); + weight_t dBC_val; + if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; } + if constexpr (!kIsVariableB) { + gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]), + !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val); + } + if constexpr (!kIsVariableC) { + gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]), + !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val); + } + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + // using Ktraits = Selective_Scan_bwd_kernel_traits; + // TODO: check this + constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_common.h b/csrc/selective_scan/selective_scan_common.h new file mode 100644 index 00000000..9140dcdf --- /dev/null +++ b/csrc/selective_scan/selective_scan_common.h @@ -0,0 +1,221 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For scalar_value_type + +#define MAX_DSTATE 256 + +using complex_t = c10::complex; + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ float3 operator+(const float3 &a, const float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } + } +}; + +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } + } +}; + +#if __CUDA_ARCH__ >= 800 +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp +// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696 +__device__ __forceinline__ complex_t cexp2f(complex_t z) { + float t = exp2f(z.real_); + float c, s; + sincosf(z.imag_, &s, &c); + return complex_t(c * t, s * t); +} + +__device__ __forceinline__ complex_t cexpf(complex_t z) { + float t = expf(z.real_); + float c, s; + sincosf(z.imag_, &s, &c); + return complex_t(c * t, s * t); +} + +template struct SSMScanOp; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } +}; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const { + complex_t a0 = complex_t(ab0.x, ab0.y); + complex_t b0 = complex_t(ab0.z, ab0.w); + complex_t a1 = complex_t(ab1.x, ab1.y); + complex_t b1 = complex_t(ab1.z, ab1.w); + complex_t out_a = a1 * a0; + complex_t out_b = a1 * b0 + b1; + return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_); + } +}; + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } +} + +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = Ktraits::kNItems; + if constexpr (!Ktraits::kIsComplex) { + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); + } else { + typename Ktraits::input_t B_vals_load[kNItems * 2]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); } + } +} + +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } +} diff --git a/csrc/selective_scan/selective_scan_fwd_bf16.cu b/csrc/selective_scan/selective_scan_fwd_bf16.cu new file mode 100644 index 00000000..2b8615b1 --- /dev/null +++ b/csrc/selective_scan/selective_scan_fwd_bf16.cu @@ -0,0 +1,10 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd_fp16.cu b/csrc/selective_scan/selective_scan_fwd_fp16.cu new file mode 100644 index 00000000..015e2a0e --- /dev/null +++ b/csrc/selective_scan/selective_scan_fwd_fp16.cu @@ -0,0 +1,10 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd_fp32.cu b/csrc/selective_scan/selective_scan_fwd_fp32.cu new file mode 100644 index 00000000..c142fe02 --- /dev/null +++ b/csrc/selective_scan/selective_scan_fwd_fp32.cu @@ -0,0 +1,10 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh new file mode 100644 index 00000000..440a2091 --- /dev/null +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -0,0 +1,345 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); + // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + } + } + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_val[r] *= kLog2e; + } else { + A_val[r].real_ *= kLog2e; + } + } + // This variable holds B * C if both B and C are constant across seqlen. If only B varies + // across seqlen, this holds C. If only C varies across seqlen, this holds B. + // If both B and C vary, this is unused. + weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + } + if constexpr (kIsVariableC) { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } + } + if constexpr (!kIsVariableB && !kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if constexpr (!kIsComplex) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } else { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); + weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); + } + } + } + } + // Initialize running total + scan_t running_prefix; + if constexpr (!kIsComplex) { + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + } else { + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + } + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); + if constexpr (!kIsComplex) { + out_vals[r][i] += thread_data[i].y * C_val; + } else { + out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2; + } + } + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + + Bvar += kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += kChunkSize * (!kIsComplex ? 1 : 2); + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block + // processing 1 row. + constexpr int kNRows = 1; + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } +} diff --git a/csrc/selective_scan/static_switch.h b/csrc/selective_scan/static_switch.h new file mode 100644 index 00000000..7920ac04 --- /dev/null +++ b/csrc/selective_scan/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/selective_scan/uninitialized_copy.cuh b/csrc/selective_scan/uninitialized_copy.cuh new file mode 100644 index 00000000..630622dd --- /dev/null +++ b/csrc/selective_scan/uninitialized_copy.cuh @@ -0,0 +1,69 @@ +/****************************************************************************** + * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include + +#include + + +namespace detail +{ + +#if defined(_NVHPC_CUDA) +template +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + // NVBug 3384810 + new (ptr) T(::cuda::std::forward(val)); +} +#else +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + *ptr = ::cuda::std::forward(val); +} + +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + new (ptr) T(::cuda::std::forward(val)); +} +#endif + +} // namespace detail diff --git a/evals/lm_harness_eval.py b/evals/lm_harness_eval.py new file mode 100644 index 00000000..d09d4053 --- /dev/null +++ b/evals/lm_harness_eval.py @@ -0,0 +1,39 @@ +import torch + +import transformers +from transformers import AutoTokenizer + +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + +from lm_eval.api.model import LM +from lm_eval.models.huggingface import HFLM +from lm_eval.api.registry import register_model +from lm_eval.__main__ import cli_evaluate + + +@register_model("mamba") +class MambaEvalWrapper(HFLM): + + AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM + + def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda", + dtype=torch.float16): + LM.__init__(self) + self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype) + self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + self.vocab_size = self.tokenizer.vocab_size + self._batch_size = batch_size if batch_size is None else 64 + self._max_length = max_length + self._device = torch.device(device) + + @property + def batch_size(self): + return self._batch_size + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + raise NotImplementedError() + + +if __name__ == "__main__": + cli_evaluate() diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py new file mode 100644 index 00000000..76d545e4 --- /dev/null +++ b/mamba_ssm/__init__.py @@ -0,0 +1,5 @@ +__version__ = "1.0.0" + +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn +from mamba_ssm.modules.mamba_simple import Mamba +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel diff --git a/mamba_ssm/models/__init__.py b/mamba_ssm/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py new file mode 100644 index 00000000..383f773f --- /dev/null +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -0,0 +1,233 @@ +# Copyright (c) 2023, Albert Gu, Tri Dao. + +import math +from functools import partial + +from collections import namedtuple + +import torch +import torch.nn as nn + +from mamba_ssm.modules.mamba_simple import Mamba, Block +from mamba_ssm.utils.generation import GenerationMixin +from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf + +try: + from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + + +def create_block( + d_model, + ssm_cfg=None, + norm_epsilon=1e-5, + rms_norm=False, + residual_in_fp32=False, + fused_add_norm=False, + layer_idx=None, + device=None, + dtype=None, +): + if ssm_cfg is None: + ssm_cfg = {} + factory_kwargs = {"device": device, "dtype": dtype} + mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) + norm_cls = partial( + nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs + ) + block = Block( + d_model, + mixer_cls, + norm_cls=norm_cls, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + ) + block.layer_idx = layer_idx + return block + + +# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 +def _init_weights( + module, + n_layer, + initializer_range=0.02, # Now only used for embedding layer. + rescale_prenorm_residual=True, + n_residuals_per_layer=1, # Change to 2 if we have MLP +): + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) + + +class MixerModel(nn.Module): + def __init__( + self, + d_model: int, + n_layer: int, + vocab_size: int, + ssm_cfg=None, + norm_epsilon: float = 1e-5, + rms_norm: bool = False, + initializer_cfg=None, + fused_add_norm=False, + residual_in_fp32=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + + self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) + + # We change the order of residual and layer norm: + # Instead of LN -> Attn / MLP -> Add, we do: + # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and + # the main branch (output of MLP / Mixer). The model definition is unchanged. + # This is for performance reason: we can fuse add + layer_norm. + self.fused_add_norm = fused_add_norm + if self.fused_add_norm: + if layer_norm_fn is None or rms_norm_fn is None: + raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") + + self.layers = nn.ModuleList( + [ + create_block( + d_model, + ssm_cfg=ssm_cfg, + norm_epsilon=norm_epsilon, + rms_norm=rms_norm, + residual_in_fp32=residual_in_fp32, + fused_add_norm=fused_add_norm, + layer_idx=i, + **factory_kwargs, + ) + for i in range(n_layer) + ] + ) + + self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( + d_model, eps=norm_epsilon, **factory_kwargs + ) + + self.apply( + partial( + _init_weights, + n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}), + ) + ) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return { + i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + for i, layer in enumerate(self.layers) + } + + def forward(self, input_ids, inference_params=None): + hidden_states = self.embedding(input_ids) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + hidden_states, residual, inference_params=inference_params + ) + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn + hidden_states = fused_add_norm_fn( + hidden_states, + self.norm_f.weight, + self.norm_f.bias, + eps=self.norm_f.eps, + residual=residual, + prenorm=False, + residual_in_fp32=self.residual_in_fp32, + ) + return hidden_states + + +class MambaLMHeadModel(nn.Module, GenerationMixin): + + def __init__( + self, + d_model: int, + n_layer: int, + vocab_size: int, + initializer_cfg=None, + pad_vocab_size_multiple: int = 1, + device=None, + dtype=None, + **backbone_kwargs, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if vocab_size % pad_vocab_size_multiple != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + self.backbone = MixerModel( + d_model=d_model, + n_layer=n_layer, + vocab_size=vocab_size, + initializer_cfg=initializer_cfg, + **backbone_kwargs, + **factory_kwargs, + ) + self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) + + # Initialize weights and apply final processing + self.apply( + partial( + _init_weights, + n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}), + ) + ) + self.tie_weights() + + def tie_weights(self): + self.lm_head.weight = self.backbone.embedding.weight + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): + """ + "position_ids" is just to be compatible with Transformer generation. We don't use it. + num_last_tokens: if > 0, only return the logits for the last n tokens + """ + hidden_states = self.backbone(input_ids, inference_params=inference_params) + if num_last_tokens > 0: + hidden_states = hidden_states[:, -num_last_tokens:] + lm_logits = self.lm_head(hidden_states) + CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) + return CausalLMOutput(logits=lm_logits) + + @classmethod + def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs): + config = load_config_hf(pretrained_model_name) + model = cls(**config, device=device, dtype=dtype, **kwargs) + model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)) + return model diff --git a/mamba_ssm/modules/__init__.py b/mamba_ssm/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py new file mode 100644 index 00000000..6b375513 --- /dev/null +++ b/mamba_ssm/modules/mamba_simple.py @@ -0,0 +1,354 @@ +# Copyright (c) 2023, Tri Dao, Albert Gu. + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from einops import rearrange, repeat + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn, causal_conv1d_update = None + +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn +except ImportError: + selective_scan_fn, mamba_inner_fn = None, None, None + +try: + from mamba_ssm.ops.triton.selective_state_update import selective_state_update +except ImportError: + selective_state_update = None + +try: + from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + + +class Mamba(nn.Module): + def __init__( + self, + d_model, + d_state=16, + d_conv=4, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + conv_bias=True, + bias=False, + use_fast_path=True, # Fused kernel options + layer_idx=None, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + self.use_fast_path = use_fast_path + self.layer_idx = layer_idx + + self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) + + self.conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + groups=self.d_inner, + padding=d_conv - 1, + **factory_kwargs, + ) + + self.activation = "silu" + self.act = nn.SiLU() + + self.x_proj = nn.Linear( + self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs + ) + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = self.dt_rank**-0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(self.dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + self.dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + self.dt_proj.bias._no_reinit = True + + # S4D real initialization + A = repeat( + torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=self.d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + def forward(self, hidden_states, inference_params=None): + """ + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + batch, seqlen, dim = hidden_states.shape + + conv_state, ssm_state = None, None + if inference_params is not None: + conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, _, _ = self.step(hidden_states, conv_state, ssm_state) + return out + + # We do matmul and transpose BLH -> HBL at the same time + xz = rearrange( + self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), + "d (b l) -> b d l", + l=seqlen, + ) + if self.in_proj.bias is not None: + xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") + + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + # In the backward pass we write dx and dz next to each other to avoid torch.cat + if self.use_fast_path and inference_params is None: # Doesn't support outputting the states + out = mamba_inner_fn( + xz, + self.conv1d.weight, + self.conv1d.bias, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias, + A, + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + ) + else: + x, z = xz.chunk(2, dim=1) + # Compute short convolution + if conv_state is not None: + conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W) + if causal_conv1d_fn is None: + x = self.act(self.conv1d(x)[..., :seqlen]) + else: + assert self.activation in ["silu", "swish"] + x = causal_conv1d_fn( + x, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + # We're careful here about the layout, to avoid extra transposes. + # We want dt to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + assert self.activation in ["silu", "swish"] + y = selective_scan_fn( + x, + dt, + A, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=ssm_state is not None, + ) + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(last_state) + y = rearrange(y, "b d l -> b l d") + out = self.out_proj(y) + return out + + def step(self, hidden_states, conv_state, ssm_state): + dtype = hidden_states.dtype + assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" + xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + x, z = xz.chunk(2, dim=-1) # (B D) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv1d.bias is not None: + x = x + self.conv1d.bias + x = self.act(x).to(dtype=dtype) + else: + x = causal_conv1d_update( + x, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + x_db = self.x_proj(x) # (B dt_rank+2*d_state) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + # Don't add dt_bias here + dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + + # SSM step + if selective_state_update is None: + # Discretize A and B + dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) + dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) + dB = torch.einsum("bd,bn->bdn", dt, B) + ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) + y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) + y = y + self.D.to(dtype) * x + y = y * self.act(z) # (B D) + else: + y = selective_state_update( + ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ) + + out = self.out_proj(y) + return out.unsqueeze(1), conv_state, ssm_state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype + ) + ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype + # ssm_dtype = torch.float32 + ssm_state = torch.zeros( + batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + batch_shape = (batch_size,) + conv_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_state, + device=self.dt_proj.weight.device, + dtype=self.dt_proj.weight.dtype, + # dtype=torch.float32, + ) + inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + # TODO: What if batch size changes between generation, and we reuse the same states? + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state + + +class Block(nn.Module): + def __init__( + self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False + ): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.mixer = mixer_cls(dim) + self.norm = norm_cls(dim) + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + + def forward( + self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Mixer(LN(residual)) + """ + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + hidden_states = self.mixer(hidden_states, inference_params=inference_params) + return hidden_states, residual + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/mamba_ssm/ops/__init__.py b/mamba_ssm/ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py new file mode 100644 index 00000000..0932c29f --- /dev/null +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -0,0 +1,345 @@ +# Copyright (c) 2023, Tri Dao, Albert Gu. + +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd + +from einops import rearrange, repeat + +from causal_conv1d import causal_conv1d_fn +import causal_conv1d_cuda +import selective_scan_cuda + + +class SelectiveScanFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, + None) + + +def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. + """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +class MambaInnerFn(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): + """ + xz: (batch, dim, seqlen) + """ + assert checkpoint_lvl in [0, 1] + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + if torch.is_autocast_enabled(): + x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) + if out_proj_bias is not None else None) + if xz.stride(-1) != 1: + xz = xz.contiguous() + conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") + x, z = xz.chunk(2, dim=1) + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) + # We're being very careful here about the layout, to avoid extra transposes. + # We want delta to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + ctx.is_variable_B = B is None + ctx.is_variable_C = C is None + ctx.B_proj_bias_is_None = B_proj_bias is None + ctx.C_proj_bias_is_None = C_proj_bias is None + if B is None: # variable B + B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if B.stride(-1) != 1: + B = B.contiguous() + if C is None: # variable C + C = x_dbl[:, -d_state:] # (bl dstate) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if C.stride(-1) != 1: + C = C.contiguous() + if D is not None: + D = D.contiguous() + out, scan_intermediates, out_z = selective_scan_cuda.fwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + ) + ctx.delta_softplus = delta_softplus + ctx.out_proj_bias_is_None = out_proj_bias is None + ctx.checkpoint_lvl = checkpoint_lvl + if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass + conv1d_out, delta = None, None + ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, + delta_proj_weight, out_proj_weight, conv1d_out, delta, + A, B, C, D, delta_bias, scan_intermediates, out) + return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) + + @staticmethod + @custom_bwd + def backward(ctx, dout): + # dout: (batch, seqlen, dim) + (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + x, z = xz.chunk(2, dim=1) + if dout.stride(-1) != 1: + dout = dout.contiguous() + if ctx.checkpoint_lvl == 1: + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), + "d (b l) -> b d l", l = L) + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + dxz = torch.empty_like(xz) # (batch, dim, seqlen) + dx, dz = dxz.chunk(2, dim=1) + dout = rearrange(dout, "b l e -> e (b l)") + dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) + dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, + ctx.delta_softplus, + True # option to recompute out_z + ) + dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) + dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None + dD = dD if D is not None else None + dx_dbl = torch.empty_like(x_dbl) + dB_proj_bias = None + if ctx.is_variable_B: + if not A.is_complex(): + dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None + dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) + dB = None + dC_proj_bias = None + if ctx.is_variable_C: + if not A.is_complex(): + dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None + dx_dbl[:, -d_state:] = dC # (bl d) + dC = None + ddelta = rearrange(ddelta, "b d l -> d (b l)") + ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) + dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) + dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") + dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) + dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) + dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( + x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True + ) + dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None + dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") + return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, + dout_proj_weight, dout_proj_bias, + dA, dB, dC, dD, + ddelta_bias if delta_bias is not None else None, + dB_proj_bias, dC_proj_bias, None) + + +def mamba_inner_fn( + xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True +): + return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) + + +def mamba_inner_ref( + xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True +): + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + x, z = xz.chunk(2, dim=1) + x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") + # We're being very careful here about the layout, to avoid extra transposes. + # We want delta to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) + delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() + delta = rearrange(delta, "d (b l) -> b d l", l=L) + if B is None: # variable B + B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + if C is None: # variable B + C = x_dbl[:, -d_state:] # (bl d) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) + return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) diff --git a/mamba_ssm/ops/triton/__init__.py b/mamba_ssm/ops/triton/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mamba_ssm/ops/triton/layernorm.py b/mamba_ssm/ops/triton/layernorm.py new file mode 100644 index 00000000..8df9d042 --- /dev/null +++ b/mamba_ssm/ops/triton/layernorm.py @@ -0,0 +1,636 @@ +# Copyright (c) 2023, Tri Dao. +# Implement residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math + +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_fwd, custom_bwd + +import triton +import triton.language as tl + + +def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + return out if not prenorm else (out, x) + + +def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + out = out.to(dtype) + return out if not prenorm else (out, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd( + x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + assert y.stride(-1) == 1 + if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + y, + weight, + bias, + residual, + residual_out, + mean, + rstd, + x.stride(0), + y.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + N, + eps, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype + return y, mean, rstd, residual_out if residual_out is not None else x + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + has_residual=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + dresidual_in, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm + ) + ctx.save_for_backward(residual_out, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + + @staticmethod + def backward(ctx, dy, *args): + x, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) + + +def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True) + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=True, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + norm_weight = norm_weight.contiguous() + if norm_bias is not None: + norm_bias = norm_bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) diff --git a/mamba_ssm/ops/triton/selective_state_update.py b/mamba_ssm/ops/triton/selective_state_update.py new file mode 100644 index 00000000..fa95de73 --- /dev/null +++ b/mamba_ssm/ops/triton/selective_state_update.py @@ -0,0 +1,192 @@ +# Copyright (c) 2023, Tri Dao. + +"""We want triton==2.1.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, + # Matrix dimensions + batch, dim, dstate, + # Strides + stride_state_batch, stride_state_dim, stride_state_dstate, + stride_x_batch, stride_x_dim, + stride_dt_batch, stride_dt_dim, + stride_dt_bias_dim, + stride_A_dim, stride_A_dstate, + stride_B_batch, stride_B_dstate, + stride_C_batch, stride_C_dstate, + stride_D_dim, + stride_z_batch, stride_z_dim, + stride_out_batch, stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + state_ptr += pid_b * stride_state_batch + x_ptr += pid_b * stride_x_batch + dt_ptr += pid_b * stride_dt_batch + B_ptr += pid_b * stride_B_batch + C_ptr += pid_b * stride_C_batch + if HAS_Z: + z_ptr += pid_b * stride_z_batch + out_ptr += pid_b * stride_out_batch + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = tl.log(1.0 + tl.exp(dt)) + A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] + state = state * dA + dB * x[:, None] + tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) + x: (batch, dim) + dt: (batch, dim) + A: (dim, dstate) + B: (batch, dstate) + C: (batch, dstate) + D: (dim,) + z: (batch, dim) + dt_bias: (dim,) + Return: + out: (batch, dim) + """ + batch, dim, dstate = state.shape + assert x.shape == (batch, dim) + assert dt.shape == x.shape + assert A.shape == (dim, dstate) + assert B.shape == (batch, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (dim,) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (dim,) + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch) + z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0)) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. + BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 + else ((16, 4) if dstate <= 32 else + ((8, 4) if dstate <= 64 else + ((4, 4) if dstate <= 128 else + ((4, 8)))))) + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, x, dt, dt_bias, A, B, C, D, z, out, + batch, dim, dstate, + state.stride(0), state.stride(1), state.stride(2), + x.stride(0), x.stride(1), + dt.stride(0), dt.stride(1), + dt_bias.stride(0) if dt_bias is not None else 0, + A.stride(0), A.stride(1), + B.stride(0), B.stride(1), + C.stride(0), C.stride(1), + D.stride(0) if D is not None else 0, + z_strides[0], z_strides[1], + out.stride(0), out.stride(1), + dt_softplus, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + return out + + +def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) + x: (batch, dim) + dt: (batch, dim) + A: (dim, dstate) + B: (batch, dstate) + C: (batch, dstate) + D: (dim,) + z: (batch, dim) + dt_bias: (dim,) + Return: + out: (batch, dim) + """ + batch, dim, dstate = state.shape + assert x.shape == (batch, dim) + assert dt.shape == x.shape + assert A.shape == (dim, dstate) + assert B.shape == (batch, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (dim,) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (dim,) + dt = dt + dt_bias + dt = F.softplus(dt) if dt_softplus else dt + dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate) + dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate) + state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate + out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) + if D is not None: + out += (x * D).to(out.dtype) + return (out if z is None else out * F.silu(z)).to(x.dtype) diff --git a/mamba_ssm/utils/__init__.py b/mamba_ssm/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mamba_ssm/utils/generation.py b/mamba_ssm/utils/generation.py new file mode 100644 index 00000000..9d766b29 --- /dev/null +++ b/mamba_ssm/utils/generation.py @@ -0,0 +1,377 @@ +# Copyright (c) 2023, Albert Gu, Tri Dao. +import gc +import time +from collections import namedtuple +from dataclasses import dataclass, field +from functools import partial +from typing import Callable, Optional, Sequence, Union + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import Tensor +from torch.profiler import ProfilerActivity, profile, record_function +from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput + + +@dataclass +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + max_seqlen: int + max_batch_size: int + seqlen_offset: int = 0 + batch_size_offset: int = 0 + key_value_memory_dict: dict = field(default_factory=dict) + lengths_per_sample: Optional[Tensor] = None + + def reset(self, max_seqlen, max_batch_size): + self.max_seqlen = max_seqlen + self.max_batch_size = max_batch_size + self.seqlen_offset = 0 + if self.lengths_per_sample is not None: + self.lengths_per_sample.zero_() + + +# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py +# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 +def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf. Done in-place.""" + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(indices_to_remove, float("-Inf")) + + +# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py +# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 +def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf. Done in-place.""" + if top_p <= 0.0 or top_p >= 1.0: + return + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=False) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs <= (1 - top_p) + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits.masked_fill_(indices_to_remove, float("-inf")) + + +def sample(logits, top_k=1, top_p=0.0, temperature=1.0): + """Sample from top-k logits. + Arguments: + logits: Tensor of shape (batch_size, vocab_size) + """ + if top_k == 1: # Short-circuit for greedy decoding + return logits.argmax(dim=-1) + else: + if top_p > 0.0: + assert top_p <= 1.0, "top-p should be in (0, 1]." + if top_k > 0: + top_k = min(top_k, logits.size(-1)) # Safety check + logits_top, indices = torch.topk(logits, top_k, dim=-1) + if temperature != 1.0: + logits_top /= temperature + modify_logits_for_top_p_filtering(logits_top, top_p) + return indices[ + torch.arange(indices.shape[0], device=indices.device), + torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1), + ] + else: + # Clone so that when we modify for top_p we don't change the original logits + logits_top = logits / temperature if temperature != 1.0 else logits.clone() + modify_logits_for_top_p_filtering(logits_top, top_p) + return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze( + dim=-1 + ) + + +@torch.inference_mode() +def decode( + input_ids, + model, + max_length, + top_k=1, + top_p=0.0, + temperature=1.0, + eos_token_id=None, + teacher_outputs=None, + vocab_size=None, + tensor_parallel=1, + cg=False, + enable_timing=False, +): + """Decoding, either greedy or with top-k or top-p sampling. + If top-k = 0, don't limit the number of candidates (pure sampling). + Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, + then top-p. + We assume that all sequences in the same batch have the same length. + + Arguments: + input_ids: (batch, seq_len) + max_length: int + teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the + logits, the next token is taken from the teacher_outputs. Useful for testing. + Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: + sequences: (batch, max_length) + scores: tuples of (batch, vocab_size) + """ + batch_size, seqlen_og = input_ids.shape + teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 + if cg: + if not hasattr(model, "_decoding_cache"): + model._decoding_cache = None + model._decoding_cache = update_graph_cache( + model, + model._decoding_cache, + batch_size, + seqlen_og, + max_length, + tensor_parallel=tensor_parallel, + ) + inference_params = model._decoding_cache.inference_params + inference_params.reset(max_length, batch_size) + else: + inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + + def get_logits(input_ids, inference_params): + decoding = inference_params.seqlen_offset > 0 + if decoding: + position_ids = torch.full( + (batch_size, 1), + inference_params.seqlen_offset, + dtype=torch.long, + device=input_ids.device, + ) + else: + position_ids = None + if not cg or not decoding: + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=1, + ).logits.squeeze(dim=1) + else: + logits = model._decoding_cache.run( + input_ids, position_ids, inference_params.seqlen_offset + ).squeeze(dim=1) + return logits[..., :vocab_size] if vocab_size is not None else logits + + def sample_tokens(logits, inference_params): + if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: + token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) + else: + token = teacher_outputs[:, inference_params.seqlen_offset] + # return rearrange(token, "b -> b 1") + return token.unsqueeze(1) + + def should_stop(current_token, inference_params): + if inference_params.seqlen_offset == 0: + return False + if eos_token_id is not None and (current_token == eos_token_id).all(): + return True + if inference_params.seqlen_offset >= max_length - 1: + return True + return False + + start = torch.cuda.Event(enable_timing=enable_timing) + end = torch.cuda.Event(enable_timing=enable_timing) + + if enable_timing: + if tensor_parallel > 1: + torch.distributed.barrier() + start.record() + scores, sequences = [], [input_ids] + while not should_stop(sequences[-1], inference_params): + scores.append(get_logits(sequences[-1], inference_params)) + inference_params.seqlen_offset += sequences[-1].shape[1] + sequences.append(sample_tokens(scores[-1], inference_params)) + if enable_timing: + end.record() + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") + output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput + return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) + + +class GenerationMixin: + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + raise NotImplementedError + + def generate( + self, + input_ids, + max_length, + top_k=1, + top_p=0.0, + temperature=1.0, + return_dict_in_generate=False, + output_scores=False, + **kwargs, + ): + output = decode( + input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs + ) + if not output_scores: + output.scores = None + return output if return_dict_in_generate else output.sequences + + +def allocate_inference_cache( + max_batch_size, + max_seqlen, + nheads, + headdim, + layers: Union[int, Sequence], + device, + dtype=torch.float16, +): + assert dtype in [torch.float16, torch.bfloat16, torch.float32] + kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim) + if isinstance(layers, int): + layers = range(layers) + return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers} + + +@dataclass +class DecodingCGCache: + max_batch_size: int = 0 + max_seqlen: int = 0 + device = None + dtype = None + callables: dict = field(default_factory=dict) + mempool = None + inference_params: Optional[InferenceParams] = None + run: Optional[Callable] = None + + +@torch.inference_mode() +def update_graph_cache( + model, + cache, + batch_size, + seqlen_og, + max_seqlen, + decoding_seqlens=(1,), + tensor_parallel=1, + dtype=None, + n_warmups=2, +): + if cache is None: + cache = DecodingCGCache() + param_example = next(iter(model.parameters())) + device = param_example.device + if dtype is None: + dtype = param_example.dtype + if ( + (device, dtype) != (cache.device, cache.dtype) + or batch_size > cache.max_batch_size + or max_seqlen > cache.max_seqlen + ): # Invalidate the cache + cache.callables = {} + cache.mempool = None + cache.inference_params = None + gc.collect() + cache.device, cache.dtype = device, dtype + cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen + if hasattr(model, "allocate_inference_cache"): + inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) + else: + headdim = getattr( + model.config, + "head_dim", + model.config.hidden_size // model.config.num_attention_heads, + ) + inf_cache = allocate_inference_cache( + batch_size, + max_seqlen, + model.config.num_attention_heads // tensor_parallel, + headdim, + model.config.num_hidden_layers, + device, + dtype, + ) + lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) + cache.inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=seqlen_og, + key_value_memory_dict=inf_cache, + lengths_per_sample=lengths_per_sample, + ) + cache.mempool = torch.cuda.graphs.graph_pool_handle() + for decoding_seqlen in decoding_seqlens: + if (batch_size, decoding_seqlen) not in cache.callables: + cache.callables[batch_size, decoding_seqlen] = capture_graph( + model, + cache.inference_params, + batch_size, + max_seqlen, + decoding_seqlen=decoding_seqlen, + mempool=cache.mempool, + n_warmups=n_warmups, + ) + + def dispatch(input_ids, position_ids, seqlen): + batch_size, decoding_seqlen = input_ids.shape[:2] + return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) + + cache.run = dispatch + cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing + return cache + + +def capture_graph( + model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 +): + device = next(iter(model.parameters())).device + input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) + position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) + seqlen_offset_og = inference_params.seqlen_offset + inference_params.seqlen_offset = max_seqlen - decoding_seqlen + inference_params.lengths_per_sample[:] = inference_params.seqlen_offset + + # Warmup before capture + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(n_warmups): + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=decoding_seqlen, + ).logits + s.synchronize() + # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, + # which requires that graph launch and non-captured launch to not overlap (I think, + # that's how I interpret the documentation). I'm not sure if this is required. + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.cuda.current_stream().wait_stream(s) + # Captures the graph + # To allow capture, automatically sets a side stream as the current stream in the context + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=mempool): + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=decoding_seqlen, + ).logits + + def run(new_input_ids, new_position_ids, seqlen): + inference_params.lengths_per_sample[:] = seqlen + input_ids.copy_(new_input_ids) + position_ids.copy_(new_position_ids) + graph.replay() + return logits.clone() + + inference_params.seqlen_offset = seqlen_offset_og + return run diff --git a/mamba_ssm/utils/hf.py b/mamba_ssm/utils/hf.py new file mode 100644 index 00000000..0d7555ac --- /dev/null +++ b/mamba_ssm/utils/hf.py @@ -0,0 +1,23 @@ +import json + +import torch + +from transformers.utils import WEIGHTS_NAME, CONFIG_NAME +from transformers.utils.hub import cached_file + + +def load_config_hf(model_name): + resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) + return json.load(open(resolved_archive_file)) + + +def load_state_dict_hf(model_name, device=None, dtype=None): + # If not fp32, then we don't want to load directly to the GPU + mapped_device = "cpu" if dtype not in [torch.float32, None] else device + resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) + return torch.load(resolved_archive_file, map_location=mapped_device) + # Convert dtype before moving to GPU to save memory + if dtype is not None: + state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} + state_dict = {k: v.to(device=device) for k, v in state_dict.items()} + return state_dict diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..a4bbb97b --- /dev/null +++ b/setup.py @@ -0,0 +1,274 @@ +# Copyright (c) 2023, Albert Gu, Tri Dao. +import sys +import warnings +import os +import re +import ast +from pathlib import Path +from packaging.version import parse, Version +import platform + +from setuptools import setup, find_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import torch +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, +) + + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + +PACKAGE_NAME = "mamba_ssm" + +BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE" + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith("linux"): + return "linux_x86_64" + elif sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + elif sys.platform == "win32": + return "win_amd64" + else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + return nvcc_extra_args + ["--threads", "4"] + + +cmdclass = {} +ext_modules = [] + +if not SKIP_CUDA_BUILD: + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + check_if_cuda_home_none(PACKAGE_NAME) + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.6"): + raise RuntimeError( + f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " + "Note: make sure nvcc has a supported version by running nvcc -V." + ) + + cc_flag.append("-gencode") + cc_flag.append("arch=compute_70,code=sm_70") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,code=sm_80") + if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + ext_modules.append( + CUDAExtension( + name="selective_scan_cuda", + sources=[ + "csrc/selective_scan/selective_scan.cpp", + "csrc/selective_scan/selective_scan_fwd_fp32.cu", + "csrc/selective_scan/selective_scan_fwd_fp16.cu", + "csrc/selective_scan/selective_scan_fwd_bf16.cu", + "csrc/selective_scan/selective_scan_bwd_fp32_real.cu", + "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu", + "csrc/selective_scan/selective_scan_bwd_fp16_real.cu", + "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", + "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", + "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", + ], + extra_compile_args={ + "cxx": ["-O3"], + "nvcc": append_nvcc_threads( + [ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + "-lineinfo", + ] + + cc_flag + ), + }, + include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], + ) + ) + + +def get_package_version(): + with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f: + version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + local_version = os.environ.get("MAMBA_LOCAL_VERSION") + if local_version: + return f"{public_version}+{local_version}" + else: + return str(public_version) + + +def get_wheel_url(): + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + torch_cuda_version = parse(torch.version.cuda) + torch_version_raw = parse(torch.__version__) + # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 + # to save CI time. Minor versions should be compatible. + torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + mamba_ssm_version = get_package_version() + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" + torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" + cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() + + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" + wheel_url = BASE_WHEEL_URL.format( + tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename + ) + return wheel_url, wheel_filename + + +class CachedWheelsCommand(_bdist_wheel): + """ + The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot + find an existing wheel (which is currently the case for all installs). We use + the environment parameters to detect whether there is already a pre-built version of a compatible + wheel available and short-circuits the standard full build pipeline. + """ + + def run(self): + if FORCE_BUILD: + return super().run() + + wheel_url, wheel_filename = get_wheel_url() + print("Guessing wheel URL: ", wheel_url) + try: + urllib.request.urlretrieve(wheel_url, wheel_filename) + + # Make the archive + # Lifted from the root wheel processing command + # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + print("Raw wheel path", wheel_path) + os.rename(wheel_filename, wheel_path) + except urllib.error.HTTPError: + print("Precompiled wheel not found. Building from source...") + # If the wheel could not be downloaded, build from source + super().run() + + +setup( + name=PACKAGE_NAME, + version=get_package_version(), + packages=find_packages( + exclude=( + "build", + "csrc", + "include", + "tests", + "dist", + "docs", + "benchmarks", + "mamba_ssm.egg-info", + ) + ), + author="Tri Dao, Albert Gu", + author_email="tri@tridao.me, agu@cs.cmu.edu", + description="Mamba state-space model", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/state-spaces/mamba", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + ], + ext_modules=ext_modules, + cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} + if ext_modules + else { + "bdist_wheel": CachedWheelsCommand, + }, + python_requires=">=3.7", + install_requires=[ + "torch", + "packaging", + "ninja", + "einops", + "triton", + "transformers", + "causal_conv1d", + ], +) diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py new file mode 100644 index 00000000..8a834b3c --- /dev/null +++ b/tests/ops/test_selective_scan.py @@ -0,0 +1,247 @@ +# Copyright (C) 2023, Tri Dao. + +import math + +import torch +import torch.nn.functional as F +import pytest + +from einops import rearrange + +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref +from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref + + +# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) +@pytest.mark.parametrize('wtype', [torch.float32]) +# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('itype', [torch.float32]) +# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) +@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) +# @pytest.mark.parametrize('seqlen', [128]) +# @pytest.mark.parametrize("return_last_state", [False, True]) +@pytest.mark.parametrize("return_last_state", [True]) +# @pytest.mark.parametrize('has_delta_bias', [False, True]) +@pytest.mark.parametrize('has_delta_bias', [True]) +# @pytest.mark.parametrize('delta_softplus', [False, True]) +@pytest.mark.parametrize('delta_softplus', [True]) +# @pytest.mark.parametrize('has_z', [False, True]) +@pytest.mark.parametrize('has_z', [True]) +# @pytest.mark.parametrize('has_D', [False, True]) +@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("varBC_groups", [1, 2]) +# @pytest.mark.parametrize("varBC_groups", [1]) +# @pytest.mark.parametrize("is_variable_C", [False, True]) +@pytest.mark.parametrize("is_variable_C", [True]) +# @pytest.mark.parametrize("is_variable_B", [False, True]) +@pytest.mark.parametrize("is_variable_B", [True]) +def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, itype, wtype): + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + dim = 4 + dstate = 8 + is_complex = wtype == torch.complex64 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + if not is_variable_B: + B_shape = (dim, dstate) + elif varBC_groups == 1: + B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, + requires_grad=True) + if not is_variable_C: + C_shape = (dim, dstate) + elif varBC_groups == 1: + C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, + requires_grad=True) + if has_D: + D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + D = None + if has_z: + z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + else: + z = None + if has_delta_bias: + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + else: + delta_bias = None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() + C_ref = C.detach().clone().requires_grad_() + D_ref = D.detach().clone().requires_grad_() if D is not None else None + z_ref = z.detach().clone().requires_grad_() if z is not None else None + u_ref = u.detach().clone().requires_grad_() + delta_ref = delta.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + out, *rest = selective_scan_fn( + u, delta, A, B, C, D, z=z, + delta_bias=delta_bias, delta_softplus=delta_softplus, + return_last_state=return_last_state + ) + if return_last_state: + state = rest[0] + out_ref, *rest = selective_scan_ref( + u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, + delta_bias=delta_bias_ref, delta_softplus=delta_softplus, + return_last_state=return_last_state + ) + if return_last_state: + state_ref = rest[0] + # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + # dt_u = delta * u + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + if return_last_state: + print(f'State max diff: {(state - state_ref).abs().max().item()}') + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + + g = torch.randn_like(out) + out_ref.backward(g) + out.backward(g) + + print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}') + print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}') + print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') + print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') + print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') + if has_D: + print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') + if has_z: + print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}') + if has_delta_bias: + print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') + + assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) + assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) + assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) + assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, + atol=atolw if not is_variable_B else atol) + assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, + atol=atolw if not is_variable_C else atol) + if has_D: + assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) + if has_z: + assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) + if has_delta_bias: + assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) + + +@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) +# @pytest.mark.parametrize('wtype', [torch.complex64]) +# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('itype', [torch.float32]) +# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) +@pytest.mark.parametrize('seqlen', [128]) +@pytest.mark.parametrize("is_variable_C", [False, True]) +# @pytest.mark.parametrize("is_variable_C", [False]) +@pytest.mark.parametrize("is_variable_B", [False, True]) +# @pytest.mark.parametrize("is_variable_B", [True]) +def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype): + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + dim = 768 + dstate = 8 + dt_rank = 48 + is_complex = wtype == torch.complex64 + xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True) + conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True) + conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate + * (1 if not is_complex else 2), + dim, device=device, dtype=itype, requires_grad=True) + delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True) + out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True) + out_proj_bias = None + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) + if not is_variable_B else None) + C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) + if not is_variable_C else None) + D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + B_proj_bias = None + C_proj_bias = None + xz_ref = xz.detach().clone().requires_grad_() + conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_() + conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_() + x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_() + delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_() + out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_() + out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_() + if out_proj_bias is not None else None) + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() if B is not None else None + C_ref = C.detach().clone().requires_grad_() if C is not None else None + D_ref = D.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B, C, D, delta_bias=delta_bias, delta_softplus=True) + out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref, + delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref, + A_ref, B_ref, C_ref, D_ref, + delta_bias=delta_bias_ref, delta_softplus=True) + # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + # dt_u = delta * u + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + g = torch.randn_like(out) + out_ref.backward(g) + out.backward(g) + + print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}') + print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') + if not is_variable_B: + print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') + if not is_variable_C: + print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') + print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') + print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') + print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}') + print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}') + print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}') + print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}') + print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}') + + # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) + # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) + # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) + # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, + # atol=atolw if not is_variable_B else atol) + # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, + # atol=atolw if not is_variable_C else atol) + # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) + # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) diff --git a/tests/ops/triton/test_selective_state_update.py b/tests/ops/triton/test_selective_state_update.py new file mode 100644 index 00000000..70a8d79d --- /dev/null +++ b/tests/ops/triton/test_selective_state_update.py @@ -0,0 +1,49 @@ +# Copyright (C) 2023, Tri Dao. + +import math + +import torch +import torch.nn.functional as F +import pytest + +from einops import rearrange + +from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize('itype', [torch.float16]) +@pytest.mark.parametrize("has_z", [False, True]) +# @pytest.mark.parametrize('has_z', [True]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +# @pytest.mark.parametrize("dstate", [16]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +# @pytest.mark.parametrize("dim", [2048]) +def test_causal_conv1d_update(dim, dstate, has_z, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + # set seed + torch.random.manual_seed(0) + batch_size = 2 + state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) + x = torch.randn(batch_size, dim, device=device, dtype=itype) + dt = torch.randn(batch_size, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + if has_z: + z = torch.randn_like(x) + else: + z = None + state_ref = state.detach().clone() + out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) + out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)