Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM #7651

Merged
merged 50 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
59e6abf
Migrate mamba_ssm and causal_conv1d kernels to vLLM
mzusman Aug 19, 2024
d2348ec
Casual conv1d compiles
mzusman Aug 20, 2024
66ee5af
Add casual_conv1d to _custom_ops
mzusman Aug 20, 2024
7a0d206
Add mamba ops and triton kernels
mzusman Aug 20, 2024
145b6b7
Add casual_conv1d update
mzusman Aug 20, 2024
2bdd7f5
setup selective scan fwd pass
mzusman Aug 20, 2024
e25dbfe
Format
mzusman Aug 20, 2024
64b6160
Do not have a mamba layer for now, push in a future PR
mzusman Aug 20, 2024
2ff36cb
Format
mzusman Aug 20, 2024
5f9c383
Take off mamba from image and requirements
mzusman Aug 20, 2024
ac8354e
Add tests
mzusman Aug 20, 2024
ea80282
Some small fixes, tests still do not pass
mzusman Aug 22, 2024
2f15495
Fix tests
mzusman Aug 22, 2024
b51fd28
Causal conv1d tests are passing
mzusman Aug 22, 2024
0cc2252
Import
mzusman Aug 22, 2024
d65dfb6
Tests
mzusman Aug 22, 2024
e7b2b32
Format
mzusman Aug 22, 2024
2c9fe00
Cleanup
mzusman Aug 22, 2024
c82cc30
Align with main
mzusman Aug 22, 2024
6c83e5f
Format
mzusman Aug 22, 2024
cd78cf6
Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
mzusman Aug 22, 2024
b6a00cb
Add init py files
mzusman Aug 22, 2024
ef69b6c
Move kernels to cuda only
mzusman Aug 22, 2024
152f331
Revert "Move kernels to cuda only"
mzusman Aug 22, 2024
39f0fa0
move kernels to if cuda
mzusman Aug 22, 2024
42f94b7
Fix tests
mzusman Aug 22, 2024
f050781
Revert formating
mzusman Aug 25, 2024
c8ffba5
Format
mzusman Aug 25, 2024
04f947b
Add comments on adapted from mamba/casual conv1d repos
mzusman Aug 25, 2024
732db18
pare down number of w/i dtype combinations
mzusman Aug 25, 2024
fdca1ff
Clean up not used
mzusman Aug 25, 2024
fe70a39
Rename typo
mzusman Aug 25, 2024
9a0e538
Add comment on einops
mzusman Aug 25, 2024
619a40a
Remove requirement for einops
mzusman Aug 25, 2024
5d0d2db
Fix tests after paring down kernels
mzusman Aug 25, 2024
c622375
format
mzusman Aug 25, 2024
cdc9205
Fix typo
mzusman Aug 25, 2024
42d9c59
Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
mzusman Aug 25, 2024
308c922
register meta functions to the kernels
mzusman Aug 25, 2024
d921a48
Revert "register meta functions to the kernels"
mzusman Aug 25, 2024
a8078e7
move to ifndef ROCm
mzusman Aug 26, 2024
2ca8db7
Format
mzusman Aug 26, 2024
abf02fa
Reduce combinations of bool switch to reduce wheel size
mzusman Aug 27, 2024
633225c
Fix, use float as weight dtype
mzusman Aug 27, 2024
ec0112b
Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
mzusman Aug 28, 2024
1f35bbe
Take down seq_pos_idx, not used atm, will comeback in a following PR
mzusman Aug 28, 2024
bed44c4
Add comments and guard checks on disabled "features"
mzusman Aug 28, 2024
950701a
Fix header file
mzusman Aug 28, 2024
4e5d6b4
Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
mzusman Aug 28, 2024
d23a429
Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
mzusman Aug 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_MakeAvailable(cutlass)

list(APPEND VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
mzusman marked this conversation as resolved.
Show resolved Hide resolved
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
Expand Down
23 changes: 0 additions & 23 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt

COPY requirements-mamba.txt requirements-mamba.txt
RUN python3 -m pip install packaging
RUN python3 -m pip install -r requirements-mamba.txt

# cuda arch list used by torch
# can be useful for both `dev` and `test`
Expand Down Expand Up @@ -127,22 +124,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt

#################### DEV IMAGE ####################
#################### MAMBA Build IMAGE ####################
FROM dev as mamba-builder
# max jobs used for build
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}

WORKDIR /usr/src/mamba

COPY requirements-mamba.txt requirements-mamba.txt

# Download the wheel or build it if a pre-compiled release doesn't exist
RUN pip --verbose wheel -r requirements-mamba.txt \
--no-build-isolation --no-deps --no-cache-dir

#################### MAMBA Build IMAGE ####################

#################### vLLM installation IMAGE ####################
# image with vLLM installed
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
Expand Down Expand Up @@ -179,10 +160,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose

RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir

RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
Expand Down
725 changes: 725 additions & 0 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu

Large diffs are not rendered by default.

109 changes: 109 additions & 0 deletions csrc/mamba/causal_conv1d/causal_conv1d.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////

struct ConvParamsBase {
using index_t = uint32_t;

int batch, dim, seqlen, width;
bool silu_activation;

index_t x_batch_stride;
index_t x_c_stride;
index_t x_l_stride;
index_t weight_c_stride;
index_t weight_width_stride;
index_t out_batch_stride;
index_t out_c_stride;
index_t out_l_stride;

index_t conv_state_batch_stride;
index_t conv_state_c_stride;
index_t conv_state_l_stride;

// Common data pointers.
void *__restrict__ x_ptr;
void *__restrict__ weight_ptr;
void *__restrict__ bias_ptr;
void *__restrict__ out_ptr;

void *__restrict__ conv_state_ptr;

void *__restrict__ seq_idx_ptr;
void *__restrict__ seq_pos_idx_ptr;

// No __restrict__ since initial_states could be the same as final_states.
void * initial_states_ptr;
index_t initial_states_batch_stride;
index_t initial_states_l_stride;
index_t initial_states_c_stride;

void * final_states_ptr;
index_t final_states_batch_stride;
index_t final_states_l_stride;
index_t final_states_c_stride;
};


////////////////////////////////////////////////////////////////////////////////////////////////////

template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};

template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};

template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
28 changes: 28 additions & 0 deletions csrc/mamba/causal_conv1d/static_switch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h

#pragma once

/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
Loading
Loading