refactor: refactor bmm_fp8.cuh (#517)

This PR refactors `bmm_fp8.cuh` to fix the following issues:
1. We should not instantiate templates in header files: if two source files include this
header, the explicit instantiations produce conflicting symbols, so it is preferable to
instantiate templates in source files (a minimal sketch of this pattern follows below).
2. We do not want the headers to depend on torch source code, since torch APIs may change
and other backends should not have to rely on a specific version of torch. Header
dependencies should stay as simple as possible (only CUDA and libc++).
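
The sketch below illustrates this pattern with hypothetical names (`sketch.cuh`, `sketch_binding.cu`, and `scale_cast` are illustrative only, not code from this PR); both files are shown in one block for brevity:

```cpp
// sketch.cuh -- hypothetical header that depends only on CUDA headers and libc++.
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

namespace sketch {

// The template definition lives in the header, but it is never explicitly
// instantiated here, so including this header from multiple translation units
// cannot produce conflicting symbols for the instantiations.
template <typename InT, typename OutT>
cudaError_t scale_cast(const InT* in, OutT* out, int n, float scale, cudaStream_t stream) {
  // ... launch a conversion kernel here ...
  return cudaSuccess;
}

}  // namespace sketch

// sketch_binding.cu -- hypothetical source file (e.g. the torch binding): it
// includes sketch.cuh and explicitly instantiates each supported type
// combination exactly once, keeping framework-specific code out of the header.
namespace sketch {

template cudaError_t scale_cast<__nv_fp8_e4m3, __nv_bfloat16>(
    const __nv_fp8_e4m3* in, __nv_bfloat16* out, int n, float scale, cudaStream_t stream);

}  // namespace sketch
```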

cc @zhyncs
yzh119 authored Oct 9, 2024
1 parent 85b1878 commit 6e18209
Showing 2 changed files with 96 additions and 51 deletions.
95 changes: 47 additions & 48 deletions include/flashinfer/gemm/bmm_fp8.cuh
@@ -16,16 +16,40 @@
#ifndef FLASHINFER_GEMM_BMM_FP8_CUH_
#define FLASHINFER_GEMM_BMM_FP8_CUH_

// NOTE(Zihao): we should leave pytorch related includes outside of the header files.
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <cublasLt.h>
#include <cuda_fp8.h>
#include <torch/extension.h>

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>

#define FLASHINFER_CUBLAS_CHECK(EXPR) \
{ \
cublasStatus_t e = (EXPR); \
if (e != CUBLAS_STATUS_SUCCESS) { \
throw std::runtime_error("CUBLAS Error: " + std::string(cublasGetStatusString(e))); \
} \
}

#ifndef NDEBUG
#define FLASHINFER_CUBLAS_CALL(EXPR) \
{ \
cublasStatus_t e = (EXPR); \
if (e != CUBLAS_STATUS_SUCCESS) { \
std::cerr << "CUBLAS Error: " << cublasGetStatusString(e) << " (" << e << ") " << __FILE__ \
<< ": line " << __LINE__ << " at function " << #EXPR << std::endl; \
return e; \
} \
}
#else
#define FLASHINFER_CUBLAS_CALL(EXPR) \
{ \
cublasStatus_t e = (EXPR); \
if (e != CUBLAS_STATUS_SUCCESS) { \
return e; \
} \
}
#endif

namespace flashinfer {

namespace bmm_fp8 {
@@ -34,7 +58,7 @@ template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
void operator()(T* x) {
if (x != nullptr) {
TORCH_CUDABLAS_CHECK(destructor(x));
FLASHINFER_CUBLAS_CHECK(destructor(x));
}
}
};
@@ -54,12 +78,13 @@ class CuBlasLtMatmulDescriptor
public:
CuBlasLtMatmulDescriptor(cublasComputeType_t compute_type, cudaDataType_t scale_type) {
cublasLtMatmulDesc_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
FLASHINFER_CUBLAS_CHECK(cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
FLASHINFER_CUBLAS_CHECK(
::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};

@@ -69,13 +94,14 @@ class CuBlasLtMatrixLayout
CuBlasLtMatrixLayout(cudaDataType_t type, uint64_t rows, uint64_t cols, int64_t ld,
bool t = false) {
cublasLtMatrixLayout_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(
FLASHINFER_CUBLAS_CHECK(
cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
TORCH_CUDABLAS_CHECK(::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
FLASHINFER_CUBLAS_CHECK(
::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};

@@ -84,12 +110,12 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<cublasLtMatmulPrefere
public:
CuBlasLtMatmulPreference() {
cublasLtMatmulPreference_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
FLASHINFER_CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
TORCH_CUDABLAS_CHECK(
FLASHINFER_CUBLAS_CHECK(
::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};
@@ -110,8 +136,10 @@ cudaDataType_t get_cuda_data_type() {
}

template <typename AT, typename BT, typename DT>
void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size, int m, int n, int k,
const float* A_scale, const float* B_scale) {
cublasStatus_t bmm_fp8_internal_cublaslt(void* workspace, size_t workspace_size_in_bytes,
const AT* A, const BT* B, DT* D, int batch_size, int m,
int n, int k, const float* A_scale, const float* B_scale,
cublasLtHandle_t lt_handle, cudaStream_t stream) {
const void* A_scale_ptr = static_cast<const void*>(A_scale);
const void* B_scale_ptr = static_cast<const void*>(B_scale);
auto matmul_desp = CuBlasLtMatmulDescriptor(CUBLAS_COMPUTE_32F, CUDA_R_32F);
@@ -147,55 +175,26 @@ void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size,
}

CuBlasLtMatmulPreference preference;
size_t workspace_size = 1024 * 1024; // 1 MiB
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size);
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspace_size);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size_in_bytes);
cublasLtMatmulHeuristicResult_t heuristic_result = {};
int returned_result = 0;
auto lt_handle = reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
FLASHINFER_CUBLAS_CALL(cublasLtMatmulAlgoGetHeuristic(
lt_handle, matmul_desp.descriptor(), a_desp.descriptor(), b_desp.descriptor(),
d_desp.descriptor(), d_desp.descriptor(), preference.descriptor(), 1, &heuristic_result,
&returned_result));
if (returned_result == 0) {
TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
return CUBLAS_STATUS_NOT_SUPPORTED;
}

const float alpha = 1.0f;
const float beta = 0.0f;
cublasStatus_t status = cublasLtMatmul(
FLASHINFER_CUBLAS_CALL(cublasLtMatmul(
lt_handle, matmul_desp.descriptor(), &alpha, A, a_desp.descriptor(), B, b_desp.descriptor(),
&beta, nullptr, d_desp.descriptor(), D, d_desp.descriptor(), &heuristic_result.algo,
workspace.mutable_get(), workspace_size, at::cuda::getCurrentCUDAStream());
TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status));
workspace, workspace_size_in_bytes, stream));
return CUBLAS_STATUS_SUCCESS;
}

// NOTE(Zihao): templates should not be instantiated in the header files!
template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>(
const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>(
const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k,
const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, __nv_bfloat16>(
const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, __nv_bfloat16* D, int batch_size, int m, int n,
int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>(
const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, half* D, int batch_size, int m, int n, int k,
const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, __nv_bfloat16>(
const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>(
const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k,
const float* A_scale, const float* B_scale);

} // namespace bmm_fp8
} // namespace flashinfer
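
With torch-specific code out of the header, a non-torch backend can call the refactored routine directly by supplying its own workspace, cuBLASLt handle, and CUDA stream. A minimal sketch under that assumption (`run_bmm_fp8` is a hypothetical helper, not part of this PR; CUDA error checking is omitted for brevity):

```cpp
#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

#include "flashinfer/gemm/bmm_fp8.cuh"

// Hypothetical torch-free caller: all resources are created and owned here.
cublasStatus_t run_bmm_fp8(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D,
                           int batch_size, int m, int n, int k, const float* A_scale,
                           const float* B_scale) {
  cublasLtHandle_t lt_handle;
  FLASHINFER_CUBLAS_CHECK(cublasLtCreate(&lt_handle));

  cudaStream_t stream = nullptr;             // default stream, for simplicity
  size_t workspace_size = 32 * 1024 * 1024;  // 32 MiB, mirroring the torch binding below
  void* workspace = nullptr;
  cudaMalloc(&workspace, workspace_size);

  cublasStatus_t status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
      workspace, workspace_size, A, B, D, batch_size, m, n, k, A_scale, B_scale, lt_handle,
      stream);

  cudaFree(workspace);
  cublasLtDestroy(lt_handle);
  return status;
}
```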

52 changes: 49 additions & 3 deletions python/csrc/bmm_fp8.cu
@@ -21,7 +41,7 @@

#include "pytorch_extension_utils.h"

using namespace flashinfer;
namespace flashinfer {
namespace bmm_fp8 {

template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>(
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B,
__nv_bfloat16* D, int batch_size, int m, int n, int k, const float* A_scale,
const float* B_scale, cublasLtHandle_t lt_handle, cudaStream_t stream);

template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>(
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B,
half* D, int batch_size, int m, int n, int k, const float* A_scale, const float* B_scale,
cublasLtHandle_t lt_handle, cudaStream_t stream);

template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, __nv_bfloat16>(
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B,
__nv_bfloat16* D, int batch_size, int m, int n, int k, const float* A_scale,
const float* B_scale, cublasLtHandle_t lt_handle, cudaStream_t stream);

template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>(
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B,
half* D, int batch_size, int m, int n, int k, const float* A_scale, const float* B_scale,
cublasLtHandle_t lt_handle, cudaStream_t stream);

template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, __nv_bfloat16>(
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B,
__nv_bfloat16* D, int batch_size, int m, int n, int k, const float* A_scale,
const float* B_scale, cublasLtHandle_t lt_handle, cudaStream_t stream);

template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>(
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B,
half* D, int batch_size, int m, int n, int k, const float* A_scale, const float* B_scale,
cublasLtHandle_t lt_handle, cudaStream_t stream);

} // namespace bmm_fp8
} // namespace flashinfer

void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
torch::Tensor& A_scale, torch::Tensor& B_scale) {
@@ -50,16 +84,28 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
auto k = A.size(2);
auto n = B.size(2);

// Per the cublas documentation, the recommended workspace buffer size for hopper is 32MB.
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
// create an empty buffer of 32MB, with data type uint8 and on the same device as A
auto workspace_buffer = torch::empty(
{32 * 1024 * 1024}, torch::TensorOptions().dtype(torch::kUInt8).device(A.device()));
auto lt_handle = reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
auto stream = at::cuda::getCurrentCUDAStream();

// PyTorch is row major by default. cuBLASLt is column major by default.
// We need row major D as expected.
// A ^ T * B = D, so D ^ T = B ^ T * A
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(B.scalar_type(), b_type, [&] {
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(A.scalar_type(), a_type, [&] {
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] {
flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
workspace_buffer.data_ptr(), workspace_buffer.numel(),
static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
static_cast<d_type*>(D.data_ptr()), batch_size, n, m, k,
static_cast<float*>(B_scale.data_ptr()), static_cast<float*>(A_scale.data_ptr()));
static_cast<float*>(B_scale.data_ptr()), static_cast<float*>(A_scale.data_ptr()),
lt_handle, stream);
TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "bmm_fp8_internal_cublaslt failed: ",
cublasGetStatusString(status));
return true;
});
});
