Skip to content

Commit

Permalink
refactor (csrc): Restructure C++ code organization to facilitate addi…
Browse files Browse the repository at this point in the history
…ng new kernels (#169)

This PR does not change any logic or functionality in the main branch.
It restructures the existing C++ code organization to facilitate the
addition of new kernels in subsequent PRs.

In general, each customized CUDA kernel involves modifying the following
three files at the backend:

1. Registering Python bindings and adding declarations in `ops.cc`.
1. Placing the host's kernel launch function in `x.cc`, where `x` is the
operator's name.
1. Implementing CUDA kernels in `x.cuh`.
  • Loading branch information
lcy-seso authored Jan 20, 2025
1 parent 86685d5 commit 676964f
Show file tree
Hide file tree
Showing 12 changed files with 976 additions and 909 deletions.
37 changes: 26 additions & 11 deletions .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,39 @@ UseTab: Never
IndentWidth: 2
ColumnLimit: 80

AccessModifierOffset: -2

# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left

# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false

# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash

SortIncludes: true
IncludeBlocks: Regroup
IncludeCategories:
- Regex: '^<'
Priority: 4
- Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
Priority: 3
- Regex: '^"(qoda|\.\.)/'
Priority: 2
- Regex: '.*'
Priority: 1
- Regex: '<([A-Za-z0-9\Q/-_\E])+>'
Priority: 4
- Regex: '<(catch2|boost)\/'
Priority: 3
- Regex: '<([A-Za-z0-9.\Q/-_\E])+>'
Priority: 2
- Regex: '"([A-Za-z0-9.\Q/-_\E])+"'
Priority: 1

# If true, empty lines at the start of blocks are kept.
KeepEmptyLinesAtTheStartOfBlocks: false
AllowShortLoopsOnASingleLine: true
AllowShortIfStatementsOnASingleLine: true
Cpp11BracedListStyle: true
# If true, always break after the template<...> of a template declaration.
AlwaysBreakTemplateDeclarations: true
# If false, a function declaration's or function definition's parameters will
# either all be on the same line or will have one line each.
BinPackArguments: true
BreakConstructorInitializersBeforeComma: true
# The maximum number of consecutive empty lines to keep.
MaxEmptyLinesToKeep: 1
7 changes: 6 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
{
"yapf.args":["--style={based_on_s'tyle: google, column_limit: 80, indent_width: 4}"]
"yapf.args": [
"--style={based_on_s'tyle: google, column_limit: 80, indent_width: 4}"
],
"files.associations": {
"optional": "cpp"
}
}
17 changes: 14 additions & 3 deletions csrc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,26 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

namespace vptq {

// Argument-validation helpers for tensors passed into the CUDA kernels.
// Each raises a c10 error (via TORCH_CHECK) naming the offending argument.
#define CHECK_CUDA(x) \
  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Wrapped in do { ... } while (0) so the two checks expand to a single
// statement; the bare `CHECK_CUDA(x); CHECK_CONTIGUOUS(x)` form would
// leave the second check unguarded inside an unbraced if/else.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)

// NOTE(review): the trailing ';' makes `gpuErrchk(e);` expand to a double
// semicolon and breaks unbraced if/else use — kept as-is since unseen call
// sites may omit their own ';'; consider removing it in a follow-up.
#define gpuErrchk(ret) gpuAssert((ret), __FILE__, __LINE__);

class OptionalCUDAGuard {
int set_device_ = -1;
int current_device_ = -1;

public:
public:
OptionalCUDAGuard(int device) : set_device_(device) {
cudaError_t err = cudaGetDevice(&current_device_);
std::stringstream ss;
Expand All @@ -32,13 +44,12 @@ class OptionalCUDAGuard {
}
};

#define gpuErrchk(ret) gpuAssert((ret), __FILE__, __LINE__);

// Reports a failed CUDA runtime call: logs the error string together with
// the source location to stderr, then raises a c10 error via TORCH_CHECK.
// No-op when `code` is cudaSuccess.
inline void gpuAssert(cudaError_t code, const char* file, int line) {
  if (code == cudaSuccess) return;
  const char* msg = cudaGetErrorString(code);
  fprintf(stderr, "GPUassert: %s %s %d\n", msg, file, line);
  TORCH_CHECK(false, msg);
}

} // namespace vptq
25 changes: 23 additions & 2 deletions csrc/utils.cuh → csrc/cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,47 @@
// Licensed under the MIT License.
#pragma once

#include <ATen/cuda/CUDAContext.h>

#if defined(USE_ROCM)
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

#define VPTQ_LDG(arg) __ldg(arg)
#define SHFL_DOWN(val, offset) __shfl_down(val, offset)
#define WARP_SIZE warpSize

typedef __hip_bfloat162 __bfloat162;
typedef __hip_bfloat16 __bfloat16;
#else
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#define WARP_SIZE 32
#define VPTQ_LDG(arg) *(arg)
#define SHFL_DOWN(val, offset) __shfl_down_sync(0xffffffff, val, offset)

typedef __nv_bfloat162 __bfloat162;
typedef __nv_bfloat16 __bfloat16;
#endif

namespace vptq {

// Maps a c10 scalar type to the corresponding native CUDA/HIP vector type.
// The unspecialized template covers the bf16 case (member `type` is the
// platform's __bfloat16 alias); c10::Half and float get explicit mappings.
template <typename T>
struct C10ToNvType {
  using type = __bfloat16;
};

template <>
struct C10ToNvType<c10::Half> {
  using type = __half;
};

template <>
struct C10ToNvType<float> {
  using type = float;
};

namespace cuda {

constexpr int kBlockSize = 256;
Expand Down Expand Up @@ -243,4 +263,5 @@ __device__ __half operator*(const __half& a, const __half& b) {
return __hmul(a, b);
}
#endif

} // namespace vptq
Loading

0 comments on commit 676964f

Please sign in to comment.