From a40ff7436f78a0a20da593e87dfd3e3673c0819e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 2 Dec 2023 23:06:16 -0800 Subject: [PATCH] Code release --- .github/workflows/publish.yaml | 212 +++++++++ AUTHORS | 1 + LICENSE | 29 ++ README.md | 1 + causal_conv1d/__init__.py | 3 + causal_conv1d/causal_conv1d_interface.py | 104 +++++ csrc/causal_conv1d.cpp | 333 ++++++++++++++ csrc/causal_conv1d.h | 53 +++ csrc/causal_conv1d_bwd.cu | 525 +++++++++++++++++++++++ csrc/causal_conv1d_common.h | 64 +++ csrc/causal_conv1d_fwd.cu | 350 +++++++++++++++ csrc/causal_conv1d_update.cu | 96 +++++ csrc/static_switch.h | 25 ++ setup.py | 264 ++++++++++++ tests/test_causal_conv1d.py | 173 ++++++++ 15 files changed, 2233 insertions(+) create mode 100644 .github/workflows/publish.yaml create mode 100644 AUTHORS create mode 100644 LICENSE create mode 100644 README.md create mode 100644 causal_conv1d/__init__.py create mode 100644 causal_conv1d/causal_conv1d_interface.py create mode 100644 csrc/causal_conv1d.cpp create mode 100644 csrc/causal_conv1d.h create mode 100644 csrc/causal_conv1d_bwd.cu create mode 100644 csrc/causal_conv1d_common.h create mode 100644 csrc/causal_conv1d_fwd.cu create mode 100644 csrc/causal_conv1d_update.cu create mode 100644 csrc/static_switch.h create mode 100644 setup.py create mode 100644 tests/test_causal_conv1d.py diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 0000000..39fa328 --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,212 @@ +# This workflow will: +# - Create a new Github release +# - Build wheels for supported architectures +# - Deploy the wheels to the Github release +# - Release the static code to PyPi +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries + +name: Build wheels and deploy + +on: + create: + tags: + - v* + +jobs: + + setup_release: + name: Create Release + runs-on: ubuntu-latest + steps: + - name: Get the tag version + id: extract_branch + run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} + shell: bash + + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.extract_branch.outputs.branch }} + release_name: ${{ steps.extract_branch.outputs.branch }} + + build_wheels: + name: Build Wheel + needs: setup_release + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the + # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. + os: [ubuntu-20.04] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.1', '2.2.0.dev20231127'] + cuda-version: ['11.8.0', '12.2.0'] + # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. + # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. + # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) + # when building without C++11 ABI and using it on nvcr images. 
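Aside (not part of the workflow): the matrix builds every wheel twice, with and without `-D_GLIBCXX_USE_CXX11_ABI`, because importing an extension compiled against a different ABI than the local PyTorch build triggers the undefined-symbol error quoted above. A minimal sketch of how a consumer could check which variant matches their PyTorch build, using the standard `torch.compiled_with_cxx11_abi()` helper:

```python
# Report which cxx11_abi wheel variant matches the locally installed PyTorch build.
import torch

abi = "TRUE" if torch.compiled_with_cxx11_abi() else "FALSE"
print(f"local torch was built with cxx11abi{abi}; pick the wheel tagged cxx11abi{abi}")
```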
+ cxx11_abi: ['FALSE', 'TRUE'] + exclude: + # Pytorch <= 1.12 does not support Python 3.11 + - torch-version: '1.12.1' + python-version: '3.11' + # Pytorch >= 2.0 only supports Python >= 3.8 + - torch-version: '2.0.1' + python-version: '3.7' + - torch-version: '2.1.1' + python-version: '3.7' + - torch-version: '2.2.0.dev20231127' + python-version: '3.7' + # Pytorch <= 2.0 only supports CUDA <= 11.8 + - torch-version: '1.12.1' + cuda-version: '12.2.0' + - torch-version: '1.13.1' + cuda-version: '12.2.0' + - torch-version: '2.0.1' + cuda-version: '12.2.0' + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Set CUDA and PyTorch versions + run: | + echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV + + - name: Free up disk space + if: ${{ runner.os == 'Linux' }} + # https://github.com/easimon/maximize-build-space/blob/master/action.yml + # https://github.com/easimon/maximize-build-space/tree/test-report + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + + - name: Set up swap space + if: runner.os == 'Linux' + uses: pierotofy/set-swap-space@v1.0 + with: + swap-size-gb: 10 + + - name: Install CUDA ${{ matrix.cuda-version }} + if: ${{ matrix.cuda-version != 'cpu' }} + uses: Jimver/cuda-toolkit@v0.2.11 + id: cuda-toolkit + with: + cuda: ${{ matrix.cuda-version }} + linux-local-args: '["--toolkit"]' + # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 + # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} + method: 'network' + # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions, + # not just nvcc + # sub-packages: '["nvcc"]' + + - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} + run: | + pip install --upgrade pip + # If we don't install before installing Pytorch, we get error for torch 2.0.1 + # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none) + pip install lit + # We want to figure out the CUDA version to download pytorch + # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # This code is ugly, maybe there's a better way to do this. 
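Aside: the one-line Python that follows clamps the detected system CUDA version into the range of CUDA builds published for each PyTorch release. A more readable sketch of the same clamp (the version tables are copied from the workflow; the helper name `wheel_cuda_version` is only illustrative):

```python
# Readable equivalent of the TORCH_CUDA_VERSION one-liner below (same version tables).
MIN_CU = {"1.12": 113, "1.13": 116, "2.0": 117, "2.1": 118, "2.2": 118}
MAX_CU = {"1.12": 116, "1.13": 117, "2.0": 118, "2.1": 121, "2.2": 121}

def wheel_cuda_version(torch_mm: str, system_cuda: int) -> int:
    """Clamp the system CUDA version into the range shipped for this torch release."""
    return max(min(system_cuda, MAX_CU[torch_mm]), MIN_CU[torch_mm])

assert wheel_cuda_version("2.1", 122) == 121   # system CUDA 12.2, torch 2.1 -> cu121 wheel
assert wheel_cuda_version("1.12", 118) == 116  # system CUDA 11.8, torch 1.12 -> cu116 wheel
```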
+ export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))") + if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then + pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + else + pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} + fi + nvcc --version + python --version + python -c "import torch; print('PyTorch:', torch.__version__)" + python -c "import torch; print('CUDA:', torch.version.cuda)" + python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" + shell: + bash + + - name: Build wheel + run: | + # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 + # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 + # However this still fails so I'm using a newer version of setuptools + pip install setuptools==68.0.0 + pip install ninja packaging wheel + export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH + export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH + # Limit MAX_JOBS otherwise the github runner goes OOM + MAX_JOBS=2 CAUSAL_CONV1D_FORCE_BUILD="TRUE" CAUSAL_CONV1D_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist + tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} + wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") + ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} + echo "wheel_name=${wheel_name}" >> $GITHUB_ENV + + - name: Log Built Wheels + run: | + ls dist + + - name: Get the tag version + id: extract_branch + run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} + + - name: Get Release with tag + id: get_current_release + uses: joutvhu/get-release@v1 + with: + tag_name: ${{ steps.extract_branch.outputs.branch }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload Release Asset + id: upload_release_asset + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.get_current_release.outputs.upload_url }} + asset_path: ./dist/${{env.wheel_name}} + asset_name: ${{env.wheel_name}} + asset_content_type: application/* + + publish_package: + name: Publish package + needs: [build_wheels] + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + pip install ninja packaging setuptools wheel twine + # We don't want to download anything CUDA-related here + pip install torch --index-url https://download.pytorch.org/whl/cpu + + - name: Build core package + env: + CAUSAL_CONV1D_SKIP_CUDA_BUILD: "TRUE" + run: | + python setup.py sdist --dist-dir=dist + + - name: Deploy + env: + TWINE_USERNAME: "__token__" + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + run: | + python -m twine upload dist/* diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 0000000..8819385 --- /dev/null +++ b/AUTHORS @@ -0,0 +1 @@ +Tri Dao, tri@tridao.me diff 
--git a/LICENSE b/LICENSE new file mode 100644 index 0000000..5860e4b --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..4e90542 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# Causal depthwise conv1d in CUDA with a PyTorch interface diff --git a/causal_conv1d/__init__.py b/causal_conv1d/__init__.py new file mode 100644 index 0000000..cc4d610 --- /dev/null +++ b/causal_conv1d/__init__.py @@ -0,0 +1,3 @@ +__version__ = "1.0.0" + +from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update diff --git a/causal_conv1d/causal_conv1d_interface.py b/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000..f66143c --- /dev/null +++ b/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,104 @@ +# Copyright (c) 2023, Tri Dao. + +import torch +import torch.nn.functional as F + + +import causal_conv1d_cuda + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias=None, activation=None): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + ctx.save_for_backward(x, weight, bias) + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation) + return out + + @staticmethod + def backward(ctx, dout): + x, weight, bias = ctx.saved_tensors + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. 
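Aside (not part of the patch), before the C++ call below: a minimal usage sketch of the interface defined in this file, comparing the fused kernel against the pure-PyTorch `causal_conv1d_ref` defined further down. It assumes the `causal_conv1d_cuda` extension has been built and a CUDA device is available; the tensor sizes are arbitrary example values.

```python
import torch
from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", requires_grad=True)
weight = torch.randn(dim, width, device="cuda", requires_grad=True)
bias = torch.randn(dim, device="cuda", requires_grad=True)

out = causal_conv1d_fn(x, weight, bias, activation="silu")      # fused CUDA path
out_ref = causal_conv1d_ref(x, weight, bias, activation="silu")  # F.conv1d reference
print((out - out_ref).abs().max())  # expected to be at round-off level

out.sum().backward()  # exercises CausalConv1dFn.backward defined above
```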
+ dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd( + x, weight, bias, dout, None, ctx.activation + ) + return dx, dweight, dbias if bias is not None else None, None + + +def causal_conv1d_fn(x, weight, bias=None, activation=None): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply(x, weight, bias, activation) + + +def causal_conv1d_ref(x, weight, bias=None, activation=None): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + out = out[..., :seqlen] + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): + """ + x: (batch, dim) + conv_state: (batch, dim, width) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation) + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None): + """ + x: (batch, dim) + conv_state: (batch, dim, width) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + batch, dim = x.shape + width = weight.shape[1] + assert conv_state.shape == (batch, dim, width) + assert weight.shape == (dim, width) + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + out = torch.sum(conv_state * weight, dim=-1) # (B D) + if bias is not None: + out += bias + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/csrc/causal_conv1d.cpp b/csrc/causal_conv1d.cpp new file mode 100644 index 0000000..1c80516 --- /dev/null +++ b/csrc/causal_conv1d.cpp @@ -0,0 +1,333 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include "causal_conv1d.h" + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) 
\ + if (WTYPE == at::ScalarType::Half) { \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::BFloat16) { \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template +void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template +void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +void set_conv_params_fwd(ConvParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t width, + // device pointers + const at::Tensor x, + const at::Tensor weight, + const at::Tensor out, + void* bias_ptr, + bool silu_activation) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.width = width; + + params.silu_activation = silu_activation; + + // Set the pointers and strides. + params.x_ptr = x.data_ptr(); + params.weight_ptr = weight.data_ptr(); + params.bias_ptr = bias_ptr; + params.out_ptr = out.data_ptr(); + // All stride are in elements, not bytes. + params.x_batch_stride = x.stride(0); + params.x_c_stride = x.stride(1); + params.x_l_stride = x.stride(-1); + params.weight_c_stride = weight.stride(0); + params.weight_width_stride = weight.stride(1); + params.out_batch_stride = out.stride(0); + params.out_c_stride = out.stride(1); + params.out_l_stride = out.stride(-1); +} + + +void set_conv_params_bwd(ConvParamsBwd ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t width, + // device pointers + const at::Tensor x, + const at::Tensor weight, + void* bias_ptr, + const at::Tensor dout, + const at::Tensor dx, + const at::Tensor dweight, + void* dbias_ptr, + bool silu_activation) { + // Pass in "dout" instead of "out", we're not gonna use "out" at all. + set_conv_params_fwd(params, batch, dim, seqlen, width, + x, weight, dout, bias_ptr, silu_activation); + + // Set the pointers and strides. + params.dout_ptr = dout.data_ptr(); + params.dx_ptr = dx.data_ptr(); + params.dweight_ptr = dweight.data_ptr(); + params.dbias_ptr = dbias_ptr; + // All stride are in elements, not bytes. 
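Aside on the stride convention used throughout these params: PyTorch reports strides in element counts, exactly what the `index_t` strides stored here expect, and the "channel-last" layout the kernels special-case (`x.stride(1) == 1`) is what a transpose of a `(batch, seqlen, dim)` contiguous tensor produces. A small hedged illustration (example sizes only):

```python
import torch

x = torch.randn(2, 128, 64)      # (batch, seqlen, dim), contiguous
print(x.stride())                # (8192, 64, 1): element strides, not bytes
x_cl = x.transpose(1, 2)         # viewed as (batch, dim, seqlen) with dim fastest
print(x_cl.stride())             # (8192, 1, 64): x_cl.stride(1) == 1, the channel-last case
```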
+ params.dout_batch_stride = dout.stride(0); + params.dout_c_stride = dout.stride(1); + params.dout_l_stride = dout.stride(2); + params.dweight_c_stride = dweight.stride(0); + params.dweight_width_stride = dweight.stride(1); + params.dx_batch_stride = dx.stride(0); + params.dx_c_stride = dx.stride(1); + params.dx_l_stride = dx.stride(2); +} + +at::Tensor +causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, + const c10::optional &bias_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim, seqlen); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); + const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; + + if (is_channel_last) { + TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); + } + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_.has_value() ? 
bias_.value().data_ptr() : nullptr, + silu_activation); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { + DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] { + if (!is_channel_last) { + causal_conv1d_fwd_cuda(params, stream); + } else { + causal_conv1d_channellast_fwd_cuda(params, stream); + } + }); + }); + return out; +} + +std::vector +causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight, + const c10::optional &bias_, + at::Tensor &dout, + c10::optional &dx_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + TORCH_CHECK(dout.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int width = weight.size(-1); + + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + CHECK_SHAPE(x, batch_size, dim, seqlen); + CHECK_SHAPE(weight, dim, width); + CHECK_SHAPE(dout, batch_size, dim, seqlen); + + TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); + const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; + if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); } + if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); } + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + at::Tensor dx; + if (dx_.has_value()) { + dx = dx_.value(); + TORCH_CHECK(dx.scalar_type() == input_type); + TORCH_CHECK(dx.is_cuda()); + CHECK_SHAPE(dx, batch_size, dim, seqlen); + if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); } + if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); } + } else { + dx = torch::empty_like(x); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + + at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat)); + at::Tensor dbias; + if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); } + + ConvParamsBwd params; + set_conv_params_bwd(params, batch_size, dim, seqlen, width, + x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr, + dout, dx, dweight, bias_.has_value() ? 
dbias.data_ptr() : nullptr, + silu_activation); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] { + DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] { + if (!is_channel_last) { + causal_conv1d_bwd_cuda(params, stream); + } else { + causal_conv1d_channellast_bwd_cuda(params, stream); + } + }); + }); + return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias}; +} + +at::Tensor +causal_conv1d_update(const at::Tensor &x, + const at::Tensor &conv_state, + const at::Tensor &weight, + const c10::optional &bias_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + TORCH_CHECK(conv_state.scalar_type() == input_type); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(conv_state.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim); + CHECK_SHAPE(conv_state, batch_size, dim, width); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, + bias_.has_value() ? bias_.value().data_ptr() : nullptr, + silu_activation); + params.conv_state_ptr = conv_state.data_ptr(); + // All stride are in elements, not bytes. + params.conv_state_batch_stride = conv_state.stride(0); + params.conv_state_c_stride = conv_state.stride(1); + params.conv_state_l_stride = conv_state.stride(2); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { + DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] { + causal_conv1d_update_cuda(params, stream); + }); + }); + return out; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward"); + m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward"); + m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update"); +} diff --git a/csrc/causal_conv1d.h b/csrc/causal_conv1d.h new file mode 100644 index 0000000..844ed92 --- /dev/null +++ b/csrc/causal_conv1d.h @@ -0,0 +1,53 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ConvParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, width; + bool silu_activation; + + index_t x_batch_stride; + index_t x_c_stride; + index_t x_l_stride; + index_t weight_c_stride; + index_t weight_width_stride; + index_t out_batch_stride; + index_t out_c_stride; + index_t out_l_stride; + + index_t conv_state_batch_stride; + index_t conv_state_c_stride; + index_t conv_state_l_stride; + + // Common data pointers. + void *__restrict__ x_ptr; + void *__restrict__ weight_ptr; + void *__restrict__ bias_ptr; + void *__restrict__ out_ptr; + + void *__restrict__ conv_state_ptr; +}; + +struct ConvParamsBwd: public ConvParamsBase { + index_t dx_batch_stride; + index_t dx_c_stride; + index_t dx_l_stride; + index_t dweight_c_stride; + index_t dweight_width_stride; + index_t dout_batch_stride; + index_t dout_c_stride; + index_t dout_l_stride; + + // Common data pointers. + void *__restrict__ dx_ptr; + void *__restrict__ dweight_ptr; + void *__restrict__ dbias_ptr; + void *__restrict__ dout_ptr; +}; + diff --git a/csrc/causal_conv1d_bwd.cu b/csrc/causal_conv1d_bwd.cu new file mode 100644 index 0000000..6660975 --- /dev/null +++ b/csrc/causal_conv1d_bwd.cu @@ -0,0 +1,525 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include +#include + +#include "causal_conv1d.h" +#include "causal_conv1d_common.h" +#include "static_switch.h" + +template +struct Causal_conv1d_bwd_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr bool kSiluAct = kSiluAct_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static_assert(kWidth <= kNElts); + // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits + // (since then we'd have 8 values of float, and each round we can exchange 4 floats). + static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + using BlockReduceFloatT = cub::BlockReduce; + static constexpr int kSmemIOSize = kIsVecLoad + ? 0 + : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); + static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1); + static constexpr int kSmemSize = std::max({kSmemExchangeSize, + int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 
0 : kSmemIOSize); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_bwd_kernel(ConvParamsBwd params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr bool kSiluAct = Ktraits::kSiluAct; + constexpr int kNElts = Ktraits::kNElts; + constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds; + constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_vec = reinterpret_cast(smem_); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_store_vec = reinterpret_cast(smem_); + vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + vec_t *smem_exchange_x = reinterpret_cast(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds; + auto& smem_reduce_float = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + dim_id * params.x_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + dim_id * params.weight_c_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * params.dout_c_stride; + input_t *dx = reinterpret_cast(params.dx_ptr) + batch_id * params.dx_batch_stride + + dim_id * params.dx_c_stride; + float *dweight = reinterpret_cast(params.dweight_ptr) + dim_id * params.dweight_c_stride; + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[dim_id]); + + // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0. 
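Aside: the backward kernel that starts here walks the sequence chunk by chunk from the end; per channel it computes dx as a correlation of dout with the flipped weights and accumulates dweight and dbias. A hedged pure-PyTorch sketch of that per-channel gradient math (no SiLU), cross-checked against autograd on the same depthwise causal convolution; sizes are arbitrary example values:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, width, seqlen = 3, 4, 32
x = torch.randn(1, dim, seqlen, requires_grad=True)
w = torch.randn(dim, width, requires_grad=True)
out = F.conv1d(x, w.unsqueeze(1), padding=width - 1, groups=dim)[..., :seqlen]
dout = torch.randn_like(out)
dx_auto, dw_auto = torch.autograd.grad(out, (x, w), dout)

# dx: correlate right-padded dout with the flipped weights, per channel.
dx_manual = F.conv1d(F.pad(dout, (0, width - 1)), w.flip(-1).unsqueeze(1), groups=dim)
# dweight[d, k] = sum_t x_leftpad[d, t + k] * dout[d, t], with x left-padded by width - 1.
x_pad = F.pad(x, (width - 1, 0)).detach()
dw_manual = torch.stack([torch.stack([(x_pad[0, d, k:k + seqlen] * dout[0, d]).sum()
                                      for k in range(width)]) for d in range(dim)])
print(torch.allclose(dx_manual, dx_auto, atol=1e-4),
      torch.allclose(dw_manual, dw_auto, atol=1e-4))
```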
+ if (tidx == 0) { + if constexpr (!kSiluAct) { + input_t zeros[kNElts] = {0}; + smem_exchange[0] = reinterpret_cast(zeros)[0]; + } else { + float zeros[kNElts] = {0}; + #pragma unroll + for (int r = 0; r < kNExchangeRounds; ++r) { + smem_exchange[r * kNThreads] = reinterpret_cast(zeros)[r]; + } + } + } + + float weight_vals[kWidth]; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; } + + float dweight_vals[kWidth] = {0}; + float dbias_val = 0; + + constexpr int kChunkSize = kNThreads * kNElts; + const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; + x += (n_chunks - 1) * kChunkSize; + dout += (n_chunks - 1) * kChunkSize; + dx += (n_chunks - 1) * kChunkSize; + for (int chunk = n_chunks - 1; chunk >= 0; --chunk) { + input_t x_vals_load[2 * kNElts] = {0}; + input_t dout_vals_load[2 * kNElts] = {0}; + if constexpr(kIsVecLoad) { + Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); + Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(dout), *reinterpret_cast(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + __syncthreads(); + Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); + __syncthreads(); + Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize); + } + float dout_vals[2 * kNElts], x_vals[2 * kNElts]; + if constexpr (!kSiluAct) { + __syncthreads(); + // Thread 0 don't write yet, so that thread kNThreads - 1 can read + // the first elements of the next chunk. + if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast(dout_vals_load)[0]; } + __syncthreads(); + reinterpret_cast(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0]; + __syncthreads(); + // Now thread 0 can write the first elements of the current chunk. + if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast(dout_vals_load)[0]; } + #pragma unroll + for (int i = 0; i < 2 * kNElts; ++i) { + dout_vals[i] = float(dout_vals_load[i]); + x_vals[i] = float(x_vals_load[i]); + } + } else { + if (tidx == 0 && chunk > 0) { + if constexpr(kIsVecLoad) { + reinterpret_cast(x_vals_load)[0] = reinterpret_cast(x)[-1]; + } else { + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; } + } + } + } + __syncthreads(); + smem_exchange_x[tidx] = reinterpret_cast(x_vals_load)[1]; + __syncthreads(); + if (tidx > 0) { reinterpret_cast(x_vals_load)[0] = smem_exchange_x[tidx - 1]; } + #pragma unroll + for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } + // Recompute the output + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + float out_val = bias_val; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; + } + float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val)); + dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val + * (1.0f + out_val * (1.0f - out_sigmoid_val)); + } + // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange + // if input_t is 16 bits (since then we'd have 8 values of float) + __syncthreads(); + // Thread 0 don't write yet, so that thread kNThreads - 1 can read + // the first elements of the next chunk. 
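Aside: the SiLU branch above recomputes the pre-activation value and scales the incoming gradient by d/dz[z*sigmoid(z)] = sigmoid(z) * (1 + z * (1 - sigmoid(z))), which is the factor applied to `dout_vals[i]`. A quick hedged check of that identity against autograd:

```python
import torch
import torch.nn.functional as F

z = torch.linspace(-6, 6, 101, requires_grad=True)
F.silu(z).sum().backward()                 # autograd derivative of z * sigmoid(z)
s = torch.sigmoid(z)
manual = s * (1 + z * (1 - s))             # formula used in the kernel above
print(torch.allclose(z.grad, manual.detach(), atol=1e-6))
```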
+ if (tidx > 0) { + #pragma unroll + for (int r = 0; r < kNExchangeRounds; ++r) { + smem_exchange[r * kNThreads + tidx] = reinterpret_cast(dout_vals)[r]; + } + } + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNExchangeRounds; ++r) { + reinterpret_cast(dout_vals)[kNExchangeRounds + r] + = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)]; + } + __syncthreads(); + // Now thread 0 can write the first elements of the current chunk. + if (tidx == 0) { + #pragma unroll + for (int r = 0; r < kNExchangeRounds; ++r) { + smem_exchange[r * kNThreads + tidx] = reinterpret_cast(dout_vals)[r]; + } + } + } + dout -= kChunkSize; + x -= kChunkSize; + + #pragma unroll + for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; } + + float dx_vals[kNElts] = {0}; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1]; + } + } + + input_t dx_vals_store[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; } + if constexpr(kIsVecLoad) { + Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(dx), reinterpret_cast(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize); + } + dx -= kChunkSize; + + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1]; + } + } + } + + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + __syncthreads(); + dweight_vals[w] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]); + if (tidx == 0) { + atomicAdd(&reinterpret_cast(dweight)[w * params.dweight_width_stride], dweight_vals[w]); + } + } + if (params.bias_ptr != nullptr) { + __syncthreads(); + dbias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val); + if (tidx == 0) { + atomicAdd(&reinterpret_cast(params.dbias_ptr)[dim_id], dbias_val); + } + } +} + +template +void causal_conv1d_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) { + static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; + BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { + BOOL_SWITCH(params.silu_activation, kSiluAct, [&] { + using Ktraits = Causal_conv1d_bwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize; + dim3 grid(params.batch, params.dim); + auto kernel = &causal_conv1d_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); +} + +template +void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template +struct Causal_conv1d_channellast_bwd_kernel_traits { + // The cache line is 128 bytes, and we try to read 16 bytes per thread. + // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. + // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 + // threads). 
Each each load is 16 x 32|64 elements in the L x C dimensions. + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr bool kSiluAct = kSiluAct_; + static constexpr int kNThreads = kNThreads_; + static_assert(kNThreads % 32 == 0); + static constexpr int kNWarps = kNThreads / 32; + static constexpr int kWidth = kWidth_; + static constexpr int kChunkSizeL = kChunkSizeL_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static constexpr int kNEltsPerRow = 128 / kNBytes; + static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now + static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); + static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now + static_assert(kNColsPerWarp * kNThreadsPerRow == 32); + static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; + static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; + static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + // using BlockLoadT = cub::BlockLoad; + // using BlockStoreT = cub::BlockStore; + // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), + // sizeof(typename BlockStoreT::TempStorage)}); + // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr bool kSiluAct = Ktraits::kSiluAct; + constexpr int kNElts = Ktraits::kNElts; + constexpr int kNWarp = Ktraits::kNWarps; + constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; + constexpr int kLPerLoad = Ktraits::kNColsPerLoad; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. 
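Aside: the trait comment above fixes the tile geometry from the 128-byte cache line and 16-byte-per-thread loads. A hedged sketch that just reproduces those derived constants for fp16/bf16 vs fp32 inputs (the helper name `channellast_tile` is only illustrative; the arithmetic mirrors kNElts, kNEltsPerRow, kNThreadsPerRow, kNColsPerLoad, kNLoads with the default kNThreads=128, kChunkSizeL=64):

```python
def channellast_tile(nbytes, nthreads=128, chunk_l=64):
    nelts = 8 if nbytes == 2 else 4              # one 16-byte vector load per thread
    elts_per_row = 128 // nbytes                 # one 128-byte cache line of channels
    threads_per_row = elts_per_row // nelts      # always 8 for these sizes
    cols_per_load = (32 // threads_per_row) * (nthreads // 32)  # L positions per load
    return dict(kNElts=nelts, kNEltsPerRow=elts_per_row, kNThreadsPerRow=threads_per_row,
                kNColsPerLoad=cols_per_load, kNLoads=chunk_l // cols_per_load)

print(channellast_tile(2))  # fp16/bf16: 64 channels x 64 seq positions per block, 4 loads
print(channellast_tile(4))  # fp32:      32 channels x 64 seq positions per block, 4 loads
```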
+ __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts]; + __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts]; + + const int tid = threadIdx.x; + const int l_idx = tid / kNThreadsPerC; + const int c_idx = tid % kNThreadsPerC; + const int batch_id = blockIdx.x; + const int chunk_l_id = blockIdx.y; + const int chunk_c_id = blockIdx.z; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + weight_t *weight = reinterpret_cast(params.weight_ptr) + + chunk_c_id * kChunkSizeC * params.weight_c_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + input_t *dx = reinterpret_cast(params.dx_ptr) + batch_id * params.dx_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + float *dweight = reinterpret_cast(params.dweight_ptr) + + chunk_c_id * kChunkSizeC * params.dweight_c_stride; + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t dout_vals_load[kNElts] = {0}; + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(dout_vals_load)[0] = *reinterpret_cast(dout + l * kLPerLoad * params.dout_l_stride); + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); + } + reinterpret_cast(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(dout_vals_load)[0]; + reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + // Load the elements from the previous chunk or next chunk that are needed for convolution. 
+ if (l_idx < kWidth - 1) { + input_t dout_vals_load[kNElts] = {0}; + input_t x_vals_load[kNElts] = {0}; + if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(dout_vals_load)[0] = *reinterpret_cast(dout + kChunkSizeL * params.dout_l_stride); + } + if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); + } + reinterpret_cast(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast(dout_vals_load)[0]; + reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + // Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs + if constexpr (kSiluAct) { + if (l_idx < kWidth - 1) { + input_t x_vals_load[kNElts] = {0}; + if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + kChunkSizeL * params.x_l_stride); + } + reinterpret_cast(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + } + + __syncthreads(); + + constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); + static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); + constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; + static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); + // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity + static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); + static_assert((kLPerThread & (kLPerThread - 1)) == 0); + static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); + static_assert(kNThreadsPerRow <= 32); + + const int row_idx = tid / kNThreadsPerRow; + const int col_idx = tid % kNThreadsPerRow; + + float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 
0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); + float weight_vals[kWidth] = {0}; + if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; + } + } + float dout_vals[kLPerThread + kWidth - 1]; + float x_vals[kWidth - 1 + kLPerThread + kWidth - 1]; + #pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]); + x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); + } + + if constexpr (kSiluAct) { // Recompute the output + #pragma unroll + for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) { + x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); + } + #pragma unroll + for (int i = 0; i < kLPerThread + kWidth - 1; ++i) { + float out_val = bias_val; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { out_val += weight_vals[w] * x_vals[i + w]; } + float out_val_sigmoid = 1.f / (1.f + expf(-out_val)); + dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid)); + } + } + + float dweight_vals[kWidth] = {0}; + SumOp sum_op; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { dweight_vals[w] += x_vals[i + w] * dout_vals[i]; } + dweight_vals[w] = Allreduce::run(dweight_vals[w], sum_op); + if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) { + atomicAdd(&reinterpret_cast(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]); + } + } + + if (params.bias_ptr != nullptr) { + float dbias_val = 0.f; + for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; } + dbias_val = Allreduce::run(dbias_val, sum_op); + if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) { + atomicAdd(&reinterpret_cast(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val); + } + } + + float dx_vals[kLPerThread] = {0}; + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { + #pragma unroll + for (int w = 0; w < kWidth; ++w) { dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w]; } + } + // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads. 
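Aside: `Allreduce<kNThreadsPerRow>::run` (defined in `csrc/causal_conv1d_common.h` later in this patch) is a shuffle-XOR butterfly: after log2(N) rounds every lane in the group holds the full sum, so only `col_idx == 0` needs to issue the `atomicAdd`. A hedged Python model of that reduction pattern (the function name is illustrative, not part of the code):

```python
def butterfly_allreduce(vals):
    # Each round, lane i adds the value held by lane i ^ offset (the __shfl_xor_sync
    # partner); after log2(n) rounds every lane holds the total.
    n = len(vals)              # power of two, at most one warp (<= 32)
    offset = n // 2
    while offset >= 1:
        vals = [vals[i] + vals[i ^ offset] for i in range(n)]
        offset //= 2
    return vals

print(butterfly_allreduce([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))  # every lane: 36.0
```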
+ __syncwarp(); + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = dx_vals[i]; } + __syncthreads(); + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t dx_vals_store[kNElts]; + reinterpret_cast(dx_vals_store)[0] = reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + *reinterpret_cast(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast(dx_vals_store)[0]; + } + } + +} + +template +void causal_conv1d_channellast_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.silu_activation, kSiluAct, [&] { + using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; + const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; + dim3 grid(params.batch, n_chunks_L, n_chunks_C); + dim3 block(Ktraits::kNThreads); + auto kernel = &causal_conv1d_channellast_bwd_kernel; + // if (kSmemSize >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + // } + // kernel<<>>(params); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); + +template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/causal_conv1d_common.h 
b/csrc/causal_conv1d_common.h new file mode 100644 index 0000000..8dd6a33 --- /dev/null +++ b/csrc/causal_conv1d_common.h @@ -0,0 +1,64 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; diff --git a/csrc/causal_conv1d_fwd.cu b/csrc/causal_conv1d_fwd.cu new file mode 100644 index 0000000..74a1459 --- /dev/null +++ b/csrc/causal_conv1d_fwd.cu @@ -0,0 +1,350 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include + +#include "causal_conv1d.h" +#include "causal_conv1d_common.h" +#include "static_switch.h" + +template +struct Causal_conv1d_fwd_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static_assert(kWidth <= kNElts); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + static constexpr int kSmemIOSize = kIsVecLoad + ? 
0 + : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); + static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_vec = reinterpret_cast(smem_); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_store_vec = reinterpret_cast(smem_); + vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. + if (tidx == 0) { + input_t zeros[kNElts] = {0}; + smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; + } + + float weight_vals[kWidth]; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + + constexpr int kChunkSize = kNThreads * kNElts; + const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + input_t x_vals_load[2 * kNElts] = {0}; + if constexpr(kIsVecLoad) { + Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + __syncthreads(); + Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); + } + x += kChunkSize; + __syncthreads(); + // Thread kNThreads - 1 don't write yet, so that thread 0 can read + // the last elements of the previous chunk. + if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + __syncthreads(); + reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; + __syncthreads(); + // Now thread kNThreads - 1 can write the last elements of the current chunk. 
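Aside: the forward kernel processes the sequence in chunks of kNThreads * kNElts elements, and the `smem_exchange` handoff above carries the last kNElts inputs of the previous chunk into the current one, which is all a causal convolution of width <= kNElts needs. A hedged pure-PyTorch sketch showing that chunking with a (width - 1)-element left halo reproduces the full result (example sizes only):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, width, seqlen, chunk = 4, 4, 96, 32
x = torch.randn(1, dim, seqlen)
w = torch.randn(dim, width)

full = F.conv1d(x, w.unsqueeze(1), padding=width - 1, groups=dim)[..., :seqlen]

pieces, carry = [], torch.zeros(1, dim, width - 1)      # implicit zero history
for start in range(0, seqlen, chunk):
    seg = torch.cat([carry, x[..., start:start + chunk]], dim=-1)  # left halo + chunk
    pieces.append(F.conv1d(seg, w.unsqueeze(1), groups=dim))       # valid outputs only
    carry = seg[..., -(width - 1):]                                 # halo for next chunk
print(torch.allclose(torch.cat(pieces, dim=-1), full, atol=1e-6))
```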
+ if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + + float x_vals[2 * kNElts]; + #pragma unroll + for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } + + float out_vals[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = bias_val; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; + } + } + + if (params.silu_activation) { + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + } + } + + input_t out_vals_store[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } + if constexpr(kIsVecLoad) { + Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); + } + out += kChunkSize; + } +} + +template +void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; + BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { + using Ktraits = Causal_conv1d_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize; + dim3 grid(params.batch, params.dim); + auto kernel = &causal_conv1d_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template +struct Causal_conv1d_channellast_fwd_kernel_traits { + // The cache line is 128 bytes, and we try to read 16 bytes per thread. + // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. + // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 + // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static_assert(kNThreads % 32 == 0); + static constexpr int kNWarps = kNThreads / 32; + static constexpr int kWidth = kWidth_; + static constexpr int kChunkSizeL = kChunkSizeL_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 
4 : 8; + static constexpr int kNEltsPerRow = 128 / kNBytes; + static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now + static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); + static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now + static_assert(kNColsPerWarp * kNThreadsPerRow == 32); + static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; + static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; + static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + // using BlockLoadT = cub::BlockLoad; + // using BlockStoreT = cub::BlockStore; + // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), + // sizeof(typename BlockStoreT::TempStorage)}); + // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + constexpr int kNWarp = Ktraits::kNWarps; + constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; + constexpr int kLPerLoad = Ktraits::kNColsPerLoad; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; + + const int tid = threadIdx.x; + const int l_idx = tid / kNThreadsPerC; + const int c_idx = tid % kNThreadsPerC; + const int batch_id = blockIdx.x; + const int chunk_l_id = blockIdx.y; + const int chunk_c_id = blockIdx.z; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + weight_t *weight = reinterpret_cast(params.weight_ptr) + + chunk_c_id * kChunkSizeC * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); + } + reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + // Load the elements from the previous chunk that are needed for convolution. 
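    // Rows [0, kWidth - 1) of x_smem form the halo: the last kWidth - 1 timesteps of
    // the previous L-chunk, or zeros when this chunk starts at the beginning of the
    // sequence or the channel block runs past params.dim. With the halo in place,
    // every output position in this chunk can look back kWidth - 1 steps using
    // shared memory only.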
+ if (l_idx < kWidth - 1) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); + } + reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + + __syncthreads(); + + constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); + static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); + constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; + static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); + // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity + static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); + static_assert((kLPerThread & (kLPerThread - 1)) == 0); + static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); + static_assert(kNThreadsPerRow <= 32); + + const int row_idx = tid / kNThreadsPerRow; + const int col_idx = tid % kNThreadsPerRow; + + float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); + float weight_vals[kWidth] = {0}; + if (chunk_c_id + kChunkSizeC + row_idx < params.dim) { + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; + } + } + float x_vals[kWidth - 1 + kLPerThread]; + #pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); + } + + float out_vals[kLPerThread]; + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { + out_vals[i] = bias_val; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { out_vals[i] += weight_vals[w] * x_vals[i + w]; } + if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } + } + + // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads. 
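    // The x_smem buffer is reused in place for the outputs below. The only
    // cross-thread overlap is at the kWidth - 1 boundary positions: a thread reads
    // L positions that the thread with the same row_idx and col_idx + 1 is about to
    // overwrite with its out_vals. Those two threads are consecutive lanes of one
    // warp, because the kNThreadsPerRow threads of a row are contiguous and
    // warp-aligned (kNThreadsPerRow is a power of two, at most 32), so a warp-level
    // barrier is sufficient to order the reads before the overwrites.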
+ __syncwarp(); + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } + __syncthreads(); + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t out_vals_store[kNElts]; + reinterpret_cast(out_vals_store)[0] = reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast(out_vals_store)[0]; + } + } + +} + +template +void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; + const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; + // printf("n_chunks_L: %d, n_chunks_C: %d\n", n_chunks_L, n_chunks_C); + dim3 grid(params.batch, n_chunks_L, n_chunks_C); + dim3 block(Ktraits::kNThreads); + auto kernel = &causal_conv1d_channellast_fwd_kernel; + // if (kSmemSize >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + // } + // kernel<<>>(params); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git 
a/csrc/causal_conv1d_update.cu b/csrc/causal_conv1d_update.cu new file mode 100644 index 0000000..713e0ac --- /dev/null +++ b/csrc/causal_conv1d_update.cu @@ -0,0 +1,96 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include + +#include "causal_conv1d.h" +#include "causal_conv1d_common.h" +#include "static_switch.h" + +template +struct Causal_conv1d_update_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_update_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y * kNThreads + tidx; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride + + channel_id * params.conv_state_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 
0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + float weight_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + } + + float x_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } + x_vals[kWidth - 1] = float(x[0]); + #pragma unroll + for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } + } + + float out_val = bias_val; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } + if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } + if (channel_id < params.dim) { out[0] = input_t(out_val); } +} + +template +void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + using Ktraits = Causal_conv1d_update_kernel_traits; + dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); + auto kernel = &causal_conv1d_update_kernel; + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/static_switch.h b/csrc/static_switch.h new file mode 100644 index 0000000..0f4ad3e --- /dev/null +++ b/csrc/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..12e36bf --- /dev/null +++ b/setup.py @@ -0,0 +1,264 @@ +# Copyright (c) 2023, Tri Dao. 
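# The build can be steered with the environment variables defined below
# (CAUSAL_CONV1D_FORCE_BUILD, CAUSAL_CONV1D_SKIP_CUDA_BUILD,
# CAUSAL_CONV1D_FORCE_CXX11_ABI). Illustrative invocations, assuming a local
# checkout of this repository:
#
#   CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install .               # always compile from source
#   CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python setup.py sdist   # package sources only, no nvcc needed
#   CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE pip install .           # build with the C++11 ABI (e.g. for nvcr images)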
+import sys +import warnings +import os +import re +import ast +from pathlib import Path +from packaging.version import parse, Version +import platform + +from setuptools import setup, find_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import torch +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, +) + + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + +PACKAGE_NAME = "causal_conv1d" + +BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE" + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith("linux"): + return "linux_x86_64" + elif sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + elif sys.platform == "win32": + return "win_amd64" + else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + return nvcc_extra_args + ["--threads", "4"] + + +cmdclass = {} +ext_modules = [] + +if not SKIP_CUDA_BUILD: + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + check_if_cuda_home_none("causal_conv1d") + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.6"): + raise RuntimeError( + "causal_conv1d is only supported on CUDA 11.6 and above. " + "Note: make sure nvcc has a supported version by running nvcc -V." 
+ ) + + cc_flag.append("-gencode") + cc_flag.append("arch=compute_70,code=sm_70") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,code=sm_80") + if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + ext_modules.append( + CUDAExtension( + name="causal_conv1d_cuda", + sources=[ + "csrc/causal_conv1d.cpp", + "csrc/causal_conv1d_fwd.cu", + "csrc/causal_conv1d_bwd.cu", + "csrc/causal_conv1d_update.cu", + ], + extra_compile_args={ + "cxx": ["-O3"], + "nvcc": append_nvcc_threads( + [ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + "-lineinfo", + ] + + cc_flag + ), + }, + include_dirs=[this_dir], + ) + ) + + +def get_package_version(): + with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f: + version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION") + if local_version: + return f"{public_version}+{local_version}" + else: + return str(public_version) + + +def get_wheel_url(): + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + torch_cuda_version = parse(torch.version.cuda) + torch_version_raw = parse(torch.__version__) + # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 + # to save CI time. Minor versions should be compatible. + torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + causal_conv1d_version = get_package_version() + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" + torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" + cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() + + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" + wheel_url = BASE_WHEEL_URL.format( + tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename + ) + return wheel_url, wheel_filename + + +class CachedWheelsCommand(_bdist_wheel): + """ + The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot + find an existing wheel (which is currently the case for all installs). We use + the environment parameters to detect whether there is already a pre-built version of a compatible + wheel available and short-circuits the standard full build pipeline. 
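    For example, with the naming scheme produced by get_wheel_url above, a build for
    (hypothetically) version 1.0.0 of this package against torch 2.1 + CUDA 11.8 on
    CPython 3.10 / Linux without the C++11 ABI would look for a wheel named
    causal_conv1d-1.0.0+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
    on the corresponding GitHub release.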
+ """ + + def run(self): + if FORCE_BUILD: + return super().run() + + wheel_url, wheel_filename = get_wheel_url() + print("Guessing wheel URL: ", wheel_url) + try: + urllib.request.urlretrieve(wheel_url, wheel_filename) + + # Make the archive + # Lifted from the root wheel processing command + # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + print("Raw wheel path", wheel_path) + os.rename(wheel_filename, wheel_path) + except urllib.error.HTTPError: + print("Precompiled wheel not found. Building from source...") + # If the wheel could not be downloaded, build from source + super().run() + + +setup( + name=PACKAGE_NAME, + version=get_package_version(), + packages=find_packages( + exclude=( + "build", + "csrc", + "include", + "tests", + "dist", + "docs", + "benchmarks", + "causal_conv1d.egg-info", + ) + ), + author="Tri Dao", + author_email="tri@tridao.me", + description="Causal depthwise conv1d in CUDA, with a PyTorch interface", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/Dao-AILab/causal-conv1d", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + ], + ext_modules=ext_modules, + cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} + if ext_modules + else { + "bdist_wheel": CachedWheelsCommand, + }, + python_requires=">=3.7", + install_requires=[ + "torch", + "packaging", + "ninja", + ], +) diff --git a/tests/test_causal_conv1d.py b/tests/test_causal_conv1d.py new file mode 100644 index 0000000..6e5985c --- /dev/null +++ b/tests/test_causal_conv1d.py @@ -0,0 +1,173 @@ +# Copyright (C) 2023, Tri Dao. 
+ +import math + +import torch +import pytest + +from einops import rearrange + +from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref +from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref + + +@pytest.mark.parametrize("channel_last", [False, True]) +# @pytest.mark.parametrize('channel_last', [True]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize('itype', [torch.float16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +# @pytest.mark.parametrize('silu_activation', [True]) +@pytest.mark.parametrize("has_bias", [False, True]) +# @pytest.mark.parametrize('has_bias', [True]) +@pytest.mark.parametrize("width", [2, 3, 4]) +# @pytest.mark.parametrize('width', [2]) +@pytest.mark.parametrize( + "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] +) +# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) +# @pytest.mark.parametrize('seqlen', [128]) +def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + # batch_size = 1 + dim = 4096 + 32 # Try dim not divisible by 64 + # dim = 64 + if not channel_last: + x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() + else: + x = rearrange( + torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" + ).requires_grad_() + weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + bias = None + x_ref = x.detach().clone().requires_grad_() + weight_ref = weight.detach().clone().requires_grad_() + bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None + activation = None if not silu_activation else "silu" + out = causal_conv1d_fn(x, weight, bias, activation=activation) + out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + g = torch.randn_like(out) + out_ref.backward(g) + out.backward(g) + + print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") + print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") + if has_bias: + print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") + + assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) + assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) + if has_bias: + assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize('itype', [torch.float16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +# @pytest.mark.parametrize('silu_activation', [False]) +@pytest.mark.parametrize("has_bias", [False, True]) +# @pytest.mark.parametrize('has_bias', [True]) +@pytest.mark.parametrize("width", [2, 3, 4]) +# 
@pytest.mark.parametrize('width', [2]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +# @pytest.mark.parametrize("dim", [2048]) +def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + # batch_size = 1 + # dim = 64 + x = torch.randn(batch_size, dim, device=device, dtype=itype) + conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype) + weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + bias = None + conv_state_ref = conv_state.detach().clone() + activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation) + out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.equal(conv_state, conv_state_ref) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +# @pytest.mark.parametrize("channel_last", [False, True]) +@pytest.mark.parametrize('channel_last', [True]) +# @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('itype', [torch.bfloat16]) +# @pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize('silu_activation', [True]) +# @pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize('has_bias', [True]) +# @pytest.mark.parametrize("width", [2, 3, 4]) +@pytest.mark.parametrize('width', [4]) +@pytest.mark.parametrize( + # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] + "seqlen", [2048] +) +# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) +# @pytest.mark.parametrize('seqlen', [128]) +def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last): + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 2 + # batch_size = 1 + dim = 4096 + 32 # Try dim not divisible by 64 + # dim = 64 + if not channel_last: + x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() + else: + x = rearrange( + torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" + ).requires_grad_() + weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + bias = None + activation = None if not silu_activation else "silu" + out0 = causal_conv1d_fn(x, weight, bias, activation=activation) + g = torch.randn_like(out0) + dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g) + dw_atol = 1e-4 + db_atol = 1e-4 + + for i in range(10000): + out = causal_conv1d_fn(x, weight, bias, activation=activation) + dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g) + dw_equal = torch.allclose(dw, dw0, atol=dw_atol) + # if not dw_equal: + # breakpoint() + if has_bias: + db_equal = torch.allclose(db, db0, atol=db_atol) + # if not db_equal: + # 
breakpoint()
+        assert torch.equal(out, out0)
+        assert torch.equal(dx, dx0)
+        assert dw_equal
+        if has_bias:
+            assert db_equal
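A minimal PyTorch sketch of the single-token update semantics implemented by causal_conv1d_update.cu above: shift the per-channel state window left by one sample, append the incoming sample, then take the dot product with the filter plus bias and an optional SiLU. The function name here is illustrative; the package's causal_conv1d_update_ref is the authoritative reference. Assumptions: x is (batch, dim), conv_state is (batch, dim, width) and is updated in place, weight is (dim, width), bias is (dim,) or None.

import torch

def naive_causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
    dtype_in = x.dtype
    # Shift the rolling window left by one step and append the new sample,
    # mirroring the conv_state writes in the update kernel.
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, :, -1] = x
    # out[b, d] = bias[d] + sum_w weight[d, w] * conv_state[b, d, w]
    out = torch.einsum("bdw,dw->bd", conv_state.to(weight.dtype), weight)
    if bias is not None:
        out = out + bias
    if activation == "silu":
        out = out / (1 + torch.exp(-out))
    return out.to(dtype_in)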