Skip to content

Commit

Permalink
Merge pull request #78 from YouJiacheng/better-kernel-binding-utilities
Browse files Browse the repository at this point in the history
Improve TK's custom kernel binding utilities
  • Loading branch information
benjaminfspector authored Jan 2, 2025
2 parents cdfce88 + da79300 commit de5cf82
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 50 deletions.
72 changes: 24 additions & 48 deletions include/pyutils/pyutils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ template<ducks::gl::all GL> struct from_object<GL> {
std::array<int, 4> shape = {1, 1, 1, 1};
auto py_shape = obj.attr("shape").cast<pybind11::tuple>();
size_t dims = py_shape.size();
for (size_t i = 0; i < dims && i < 4; ++i) {
if (dims > 4) {
throw std::runtime_error("Expected Tensor.ndim <= 4");
}
for (size_t i = 0; i < dims; ++i) {
shape[4 - dims + i] = pybind11::cast<int>(py_shape[i]);
}

Expand All @@ -45,54 +48,27 @@ template<ducks::gl::all GL> struct from_object<GL> {

template<typename T> concept has_dynamic_shared_memory = requires(T t) { { t.dynamic_shared_memory() } -> std::convertible_to<int>; };

#define EMPTY()
#define DEFER(m) m EMPTY EMPTY()()
#define EVAL(...) EVAL1024(__VA_ARGS__)
#define EVAL1024(...) EVAL512(EVAL512(__VA_ARGS__))
#define EVAL512(...) EVAL256(EVAL256(__VA_ARGS__))
#define EVAL256(...) EVAL128(EVAL128(__VA_ARGS__))
#define EVAL128(...) EVAL64(EVAL64(__VA_ARGS__))
#define EVAL64(...) EVAL32(EVAL32(__VA_ARGS__))
#define EVAL32(...) EVAL16(EVAL16(__VA_ARGS__))
#define EVAL16(...) EVAL8(EVAL8(__VA_ARGS__))
#define EVAL8(...) EVAL4(EVAL4(__VA_ARGS__))
#define EVAL4(...) EVAL2(EVAL2(__VA_ARGS__))
#define EVAL2(...) EVAL1(EVAL1(__VA_ARGS__))
#define EVAL1(...) __VA_ARGS__
#define MAP(m, struct_type, idx, first, ...) \
m(struct_type, idx, first) \
__VA_OPT__( \
, DEFER(_MAP)()(m, struct_type, (idx+1), __VA_ARGS__) \
)
#define _MAP() MAP
#define ITER(m, struct_type, ...) EVAL(MAP(m, struct_type, 0, __VA_ARGS__))

#define EXPAND_MEMBER_ACCESS(struct_type, idx, x) \
kittens::py::from_object<decltype(std::declval<struct_type>().x)>::make(args[idx])

#define GENERATE_CONSTRUCTOR(struct_type, ...) \
[](pybind11::args args) { \
return struct_type{ \
ITER(EXPAND_MEMBER_ACCESS, struct_type, __VA_ARGS__) \
}; \
}

#define BIND_KERNEL(module, name, kernel, globals_struct, ...) \
module.def(name, [](pybind11::args args) { \
globals_struct __g__{ ITER(EXPAND_MEMBER_ACCESS, globals_struct, __VA_ARGS__) }; \
if constexpr (kittens::py::has_dynamic_shared_memory<globals_struct>) { \
int __dynamic_shared_memory__ = (int)__g__.dynamic_shared_memory(); \
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, __dynamic_shared_memory__); \
kernel<<<__g__.grid(), __g__.block(), __dynamic_shared_memory__>>>(__g__); \
} else { \
kernel<<<__g__.grid(), __g__.block()>>>(__g__); \
} \
template<typename> struct trait;
template<typename MT, typename T> struct trait<MT T::*> { using member_type = MT; using type = T; };
template<typename> using object = pybind11::object;
template<auto kernel, typename TGlobal> static void bind_kernel(auto m, auto name, auto TGlobal::*... member_ptrs) {
m.def(name, [](object<decltype(member_ptrs)>... args) {
TGlobal __g__ {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
if constexpr (has_dynamic_shared_memory<TGlobal>) {
int __dynamic_shared_memory__ = (int)__g__.dynamic_shared_memory();
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, __dynamic_shared_memory__);
kernel<<<__g__.grid(), __g__.block(), __dynamic_shared_memory__>>>(__g__);
} else {
kernel<<<__g__.grid(), __g__.block()>>>(__g__);
}
});
#define BIND_FUNCTION(module, name, function, globals_struct, ...) \
module.def(name, [](pybind11::args args) { \
globals_struct __g__{ ITER(EXPAND_MEMBER_ACCESS, globals_struct, __VA_ARGS__) }; \
function(__g__); \
}
template<auto function, typename TGlobal> static void bind_function(auto m, auto name, auto TGlobal::*... member_ptrs) {
m.def(name, [](object<decltype(member_ptrs)>... args) {
TGlobal __g__ {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
function(__g__);
});
}

} // namespace py
} // namespace kittens
} // namespace kittens
4 changes: 2 additions & 2 deletions kernels/example_bind/example_bind.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ void run_copy_kernel(globals g) {

PYBIND11_MODULE(example_bind, m) {
m.doc() = "example_bind python module";
BIND_KERNEL(m, "copy_kernel", copy_kernel, globals, in, out); // For wrapping kernels directly.
BIND_FUNCTION(m, "wrapped_copy_kernel", run_copy_kernel, globals, in, out); // For host functions that wrap the kernel.
py::bind_kernel<copy_kernel>(m, "copy_kernel", &globals::in, &globals::out);
py::bind_function<run_copy_kernel>(m, "wrapped_copy_kernel", &globals::in, &globals::out);
}

0 comments on commit de5cf82

Please sign in to comment.