refactor: added static switch for marlin kernel dispatch (#325)
guocuimi authored Sep 3, 2024
1 parent bb739f6 commit 2a947e1
Showing 12 changed files with 332 additions and 405 deletions.
2 changes: 2 additions & 0 deletions .clang-tidy
@@ -19,13 +19,15 @@ Checks: >
-modernize-pass-by-value,
-modernize-avoid-c-arrays,
-misc-non-private-member-variables-in-classes,
-llvm-else-after-return,
-performance-unnecessary-value-param,
-bugprone-easily-swappable-parameters,
-bugprone-unused-return-value,
-bugprone-macro-parentheses,
-readability-identifier-length,
-readability-use-anyofallof,
-readability-magic-numbers,
-readability-else-after-return,
-cppcoreguidelines-avoid-magic-numbers,
-cppcoreguidelines-avoid-non-const-global-variables,
-cppcoreguidelines-special-member-functions,
2 changes: 1 addition & 1 deletion cmake/FindCUDADriver.cmake
@@ -40,7 +40,7 @@ find_package_handle_standard_args(CUDADriver
)

if(CUDADriver_FOUND)
message(STATUS "Found CUDA driver : ${CUDADriver_LIBRARY}")
message(STATUS "Found CUDADriver : ${CUDADriver_LIBRARY}")
if(NOT TARGET CUDA::driver)
add_library(CUDA::driver SHARED IMPORTED)
set_target_properties(CUDA::driver PROPERTIES
2 changes: 1 addition & 1 deletion cmake/FindNCCL.cmake
@@ -87,7 +87,7 @@ find_package_handle_standard_args(NCCL
VERSION_VAR NCCL_VERSION)

if(NCCL_FOUND)
message(STATUS "Found NCCL ${NCCL_VERSION} (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
message(STATUS "Found NCCL ${NCCL_VERSION} (include: ${NCCL_INCLUDE_DIR}, library: ${NCCL_LIBRARY})")
if(NOT TARGET NCCL::nccl)
add_library(NCCL::nccl UNKNOWN IMPORTED)
set_target_properties(NCCL::nccl PROPERTIES
21 changes: 5 additions & 16 deletions scalellm/llm_engine.py
@@ -54,11 +54,12 @@ class OutputAsyncStream:
"""A stream of RequestOutput objects, which can be used to
send responses to the client asynchronously."""

def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
def __init__(self) -> None:
# asyncio.Queue is used to store the items in the stream, not thread-safe
self._queue = asyncio.Queue()
# event loop used to schedule callbacks from other threads
self._loop = loop
self._loop = asyncio.get_running_loop()
assert self._loop is not None
self._cancelled = False

def _put_nowait(self, item):
@@ -170,14 +171,6 @@ def __init__(
options.num_handling_threads = num_handling_threads
# create the LLM handler
self._handler = LLMHandler(options)
# event loop for async stream callbacks
self._loop = None

def _ensure_event_loop(self):
# get running event loop if not set
if self._loop is None:
self._loop = asyncio.get_running_loop()
assert self._loop is not None, "No event loop found"

# schedule a request to the engine, and return a stream to receive output
async def schedule_async(
@@ -187,9 +180,7 @@ async def schedule_async(
priority: Priority = Priority.NORMAL,
stream: bool = False,
) -> OutputAsyncStream:
self._ensure_event_loop()

output_stream = OutputAsyncStream(self._loop)
output_stream = OutputAsyncStream()

def callback(output: RequestOutput) -> bool:
output.prompt = prompt
@@ -209,9 +200,7 @@ async def schedule_chat_async(
priority: Priority = Priority.NORMAL,
stream: bool = False,
) -> OutputAsyncStream:
self._ensure_event_loop()

output_stream = OutputAsyncStream(self._loop)
output_stream = OutputAsyncStream()

def callback(output: RequestOutput) -> bool:
return output_stream.put(output)
2 changes: 0 additions & 2 deletions src/kernels/quantization/marlin/CMakeLists.txt
@@ -18,8 +18,6 @@ file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/generated/*.cu")
cc_library(
NAME
marlin.kernels
HDRS
gemm_kernel_launch.cuh
SRCS
${GENERATED_SRC_FILES}
fp16_int4_gemm.cu
25 changes: 8 additions & 17 deletions src/kernels/quantization/marlin/awq_repack.cu
@@ -8,6 +8,7 @@

#include "common.h"
#include "memory.h"
#include "static_switch.h"

namespace marlin {
namespace {
@@ -185,15 +186,6 @@ __global__ void awq_marlin_repack_kernel(
}
}

#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
auto kernel = &awq_marlin_repack_kernel<repack_threads, NUM_BITS>; \
cudaFuncSetAttribute( \
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
kernel<<<blocks, repack_threads, max_shared_mem, stream>>>( \
q_weight_ptr, out_ptr, size_k, size_n); \
}

} // namespace

void awq_repack(const torch::Tensor& q_weight, // (k, n/pack_factor)
@@ -230,14 +222,13 @@ void awq_repack(const torch::Tensor& q_weight, // (k, n/pack_factor)
&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);

// NOLINTNEXTLINE
if (false) {
}
CALL_IF(4)
CALL_IF(8)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
}
NUM_BITS_SWITCH(num_bits, [&] {
auto kernel = &awq_marlin_repack_kernel<repack_threads, NUM_BITS>;
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);
kernel<<<blocks, repack_threads, max_shared_mem, stream>>>(
q_weight_ptr, out_ptr, size_k, size_n);
});
}

} // namespace marlin
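
Note: the NUM_BITS_SWITCH macro used in awq_repack.cu above comes from the new static_switch.h header added by this commit, which is not part of the excerpt shown here. A minimal sketch of such a switch, assuming the usual static-switch pattern of an immediately invoked lambda that binds the runtime value to a compile-time constant before calling the generic callback; the parameter name and the failure message are assumptions:

// Hypothetical sketch (not from the commit): maps a runtime num_bits value
// onto the compile-time constant NUM_BITS expected by the callback.
#define NUM_BITS_SWITCH(value, ...)                    \
  [&] {                                                \
    if ((value) == 4) {                                \
      constexpr static int NUM_BITS = 4;               \
      return __VA_ARGS__();                            \
    }                                                  \
    if ((value) == 8) {                                \
      constexpr static int NUM_BITS = 8;               \
      return __VA_ARGS__();                            \
    }                                                  \
    LOG(FATAL) << "Unsupported num_bits: " << (value); \
  }()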
141 changes: 56 additions & 85 deletions src/kernels/quantization/marlin/fp8_gemm.cu
@@ -30,6 +30,7 @@
#include "mma.h"
#include "numeric_conversion.h"
#include "scale_type.h"
#include "static_switch.h"

namespace marlin {
namespace {
@@ -726,31 +727,6 @@ __global__ void Marlin(
}
}

#define __CALL_IF( \
NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && num_threads == NUM_THREADS) { \
auto kernel = &Marlin<scalar_t, \
NUM_BITS, \
NUM_THREADS, \
THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, \
pipe_stages>; \
cudaFuncSetAttribute( \
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
kernel<<<blocks, NUM_THREADS, max_shared_mem, stream>>>(A_ptr, \
B_ptr, \
C_ptr, \
s_ptr, \
num_groups, \
prob_m, \
prob_n, \
prob_k, \
locks); \
}

using thread_config_t = struct {
int thread_k;
int thread_n;
@@ -929,11 +905,18 @@ exec_config_t determine_thread_config(int prob_m,
return exec_config_t{0, {-1, -1, -1}};
}

#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, NUM_THREADS)
#define NK_BLOCKS_THREADS_SWITCH( \
thread_n_blocks, thread_k_blocks, num_threads, ...) \
[&] { \
DISPATCH_NK_BLOCKS_THREADS(32, 2, 256, __VA_ARGS__); \
DISPATCH_NK_BLOCKS_THREADS(16, 4, 256, __VA_ARGS__); \
DISPATCH_NK_BLOCKS_THREADS(8, 8, 256, __VA_ARGS__); \
DISPATCH_NK_BLOCKS_THREADS(8, 4, 128, __VA_ARGS__); \
DISPATCH_NK_BLOCKS_THREADS(4, 8, 128, __VA_ARGS__); \
LOG(FATAL) << "Unsupported (N_BLOCKS, K_BLOCKS, NUM_THREADS): " \
<< thread_n_blocks << ", " << thread_k_blocks << ", " \
<< num_threads; \
}()
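
Note: DISPATCH_NK_BLOCKS_THREADS, along with the M_BLOCKS_SWITCH used further below, is defined in the new static_switch.h, which is not shown in this excerpt. Judging from the compile-time names the callback consumes, a plausible sketch is the following; it assumes the helper expands where runtime variables named thread_n_blocks, thread_k_blocks, and num_threads are in scope, as at the call site below:

// Hypothetical sketch (not from the commit): on a match, bind one candidate
// tile configuration to the compile-time constants used by the callback.
#define DISPATCH_NK_BLOCKS_THREADS(N_BLOCKS, K_BLOCKS, THREADS, ...)    \
  if (thread_n_blocks == (N_BLOCKS) && thread_k_blocks == (K_BLOCKS) && \
      num_threads == (THREADS)) {                                       \
    constexpr static int THREAD_N_BLOCKS = (N_BLOCKS);                  \
    constexpr static int THREAD_K_BLOCKS = (K_BLOCKS);                  \
    constexpr static int NUM_THREADS = (THREADS);                       \
    return __VA_ARGS__();                                               \
  }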

template <typename scalar_t>
void marlin_mm_f16i4(const void* A,
@@ -1038,24 +1021,31 @@ void marlin_mm_f16i4(const void* A,
thread_m_blocks = exec_cfg.max_m_blocks;
}

// Define kernel configurations
// NOLINTNEXTLINE
if (false) {
}
CALL_IF(8, 32, 2, 256)
CALL_IF(8, 16, 4, 256)
CALL_IF(8, 8, 8, 256)
CALL_IF(8, 8, 4, 128)
CALL_IF(8, 4, 8, 128)
else {
LOG(FATAL) << "Unsupported shapes: MNK = [" << prob_m << ", " << prob_n
<< ", " << prob_k << "]"
<< ", num_groups = " << num_groups
<< ", group_size = " << group_size
<< ", thread_m_blocks = " << thread_m_blocks
<< ", thread_n_blocks = " << thread_n_blocks
<< ", thread_k_blocks = " << thread_k_blocks;
}
M_BLOCKS_SWITCH(thread_m_blocks, [&] {
NK_BLOCKS_THREADS_SWITCH(
thread_n_blocks, thread_k_blocks, num_threads, [&] {
constexpr static int NUM_BITS = 8;
auto kernel = &Marlin<scalar_t,
NUM_BITS,
NUM_THREADS,
THREAD_M_BLOCKS,
THREAD_N_BLOCKS,
THREAD_K_BLOCKS,
pipe_stages>;
cudaFuncSetAttribute(kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
kernel<<<blocks, NUM_THREADS, max_shared_mem, stream>>>(A_ptr,
B_ptr,
C_ptr,
s_ptr,
num_groups,
prob_m,
prob_n,
prob_k,
locks);
});
});

A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
@@ -1098,43 +1088,24 @@ void fp8_gemm(const torch::Tensor& A, // (m, k)
<< " is below min_workspace_size = " << min_workspace_size;

int dev = A.get_device();
if (A.scalar_type() == at::ScalarType::Half) {
marlin_mm_f16i4<half>(A.data_ptr(),
B.data_ptr(),
C.data_ptr(),
scales.data_ptr(),
prob_m,
prob_n,
prob_k,
workspace.data_ptr(),
num_groups,
group_size,
dev,
at::cuda::getCurrentCUDAStream(dev),
thread_k,
thread_n,
sms,
marlin::max_par);
} else if (A.scalar_type() == at::ScalarType::BFloat16) {
marlin_mm_f16i4<nv_bfloat16>(A.data_ptr(),
B.data_ptr(),
C.data_ptr(),
scales.data_ptr(),
prob_m,
prob_n,
prob_k,
workspace.data_ptr(),
num_groups,
group_size,
dev,
at::cuda::getCurrentCUDAStream(dev),
thread_k,
thread_n,
sms,
marlin::max_par);
} else {
LOG(FATAL) << "Unsupported data type: " << A.scalar_type();
}
FLOAT_TYPE_SWITCH(A.scalar_type(), [&] {
marlin_mm_f16i4<scalar_t>(A.data_ptr(),
B.data_ptr(),
C.data_ptr(),
scales.data_ptr(),
prob_m,
prob_n,
prob_k,
workspace.data_ptr(),
num_groups,
group_size,
dev,
at::cuda::getCurrentCUDAStream(dev),
thread_k,
thread_n,
sms,
marlin::max_par);
});
}

} // namespace marlin
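
Note: FLOAT_TYPE_SWITCH also comes from static_switch.h (not shown here). A sketch under the assumption that it dispatches on the torch dtype and exposes the matching CUDA type to the callback as scalar_t, mirroring the two branches it replaces:

// Hypothetical sketch (not from the commit): dispatch on the torch scalar
// type and make the corresponding CUDA type visible as scalar_t.
#define FLOAT_TYPE_SWITCH(st, ...)                     \
  [&] {                                                \
    if ((st) == at::ScalarType::Half) {                \
      using scalar_t = half;                           \
      return __VA_ARGS__();                            \
    }                                                  \
    if ((st) == at::ScalarType::BFloat16) {            \
      using scalar_t = nv_bfloat16;                    \
      return __VA_ARGS__();                            \
    }                                                  \
    LOG(FATAL) << "Unsupported data type: " << (st);   \
  }()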