From 4fc1f5599d61183d993c89f409fe70d1e395d13b Mon Sep 17 00:00:00 2001 From: Shuofei Zhu Date: Thu, 12 Sep 2024 17:05:08 -0700 Subject: [PATCH] refactor the c code template in driver.py --- third_party/nvidia/backend/cuda_launcher.c | 311 ++++++++++++++++++ .../nvidia/backend/{driver.c => cuda_util.c} | 0 third_party/nvidia/backend/driver.py | 285 ++-------------- 3 files changed, 346 insertions(+), 250 deletions(-) create mode 100644 third_party/nvidia/backend/cuda_launcher.c rename third_party/nvidia/backend/{driver.c => cuda_util.c} (100%) diff --git a/third_party/nvidia/backend/cuda_launcher.c b/third_party/nvidia/backend/cuda_launcher.c new file mode 100644 index 000000000000..887ea5db2f45 --- /dev/null +++ b/third_party/nvidia/backend/cuda_launcher.c @@ -0,0 +1,311 @@ +#include "cuda.h" +#include <stdbool.h> +#include <Python.h> +#include <dlfcn.h> + +#define IS_EMPTY_HELPER(x) IS_EMPTY_HELPER_##x +#define IS_EMPTY_HELPER_ 1 +#define IS_EMPTY(x) IS_EMPTY_HELPER(x) + +// macros that should be filled in by driver.py: +// #define EXTRA_INNER_LAUNCH_PARAM_DECLS +// #define INNER_LAUNCH_CUDA_CHECK_ARGS +// #define LAUNCH_PY_ARGS +// #define PY_ARG_FORMAT_STR +// #define EXTRA_LAUNCH_PARSE_PY_ARGS +// #define DEVICE_PTR_INFO_VARS +// #define TMA_DESC_VARS +// #define EXTRA_INNER_LAUNCH_CALL_ARGS +// +// nomenclature: "EXTRA" means extra args appended to the end of the function call, so +// driver.py prepends ", " to the macro value whenever it is non-empty. +// "INNER" means the inner function call of `_launch()`. 
+//
+
+// Fail early, with a readable message, if driver.py did not prepend one of the
+// macro definitions this template expands. An empty definition (`#define X`)
+// still counts as defined, so macros that expand to nothing pass the check.
+#ifndef EXTRA_INNER_LAUNCH_PARAM_DECLS
+#error "EXTRA_INNER_LAUNCH_PARAM_DECLS must be defined"
+#endif
+
+#ifndef INNER_LAUNCH_CUDA_CHECK_ARGS
+#error "INNER_LAUNCH_CUDA_CHECK_ARGS must be defined"
+#endif
+
+#ifndef LAUNCH_PY_ARGS
+#error "LAUNCH_PY_ARGS must be defined"
+#endif
+
+#ifndef PY_ARG_FORMAT_STR
+#error "PY_ARG_FORMAT_STR must be defined"
+#endif
+
+#ifndef EXTRA_LAUNCH_PARSE_PY_ARGS
+#error "EXTRA_LAUNCH_PARSE_PY_ARGS must be defined"
+#endif
+
+#ifndef DEVICE_PTR_INFO_VARS
+#error "DEVICE_PTR_INFO_VARS must be defined"
+#endif
+
+#ifndef TMA_DESC_VARS
+#error "TMA_DESC_VARS must be defined"
+#endif
+
+#ifndef EXTRA_INNER_LAUNCH_CALL_ARGS
+#error "EXTRA_INNER_LAUNCH_CALL_ARGS must be defined"
+#endif
+
+// Sets a Python RuntimeError carrying `msg`.
+// The helpers below are reached from `_launch()`, which `launch()` calls
+// between Py_BEGIN_ALLOW_THREADS and Py_END_ALLOW_THREADS, so the GIL is
+// acquired around the C-API call and released again afterwards.
+static inline void setRuntimeError(const char *msg) {
+  PyGILState_STATE gil_state = PyGILState_Ensure();
+  PyErr_SetString(PyExc_RuntimeError, msg);
+  PyGILState_Release(gil_state);
+}
+
+// Turns a failing CUresult into a pending Python RuntimeError whose message is
+// "Triton Error [CUDA]: <driver error string>". Does nothing on CUDA_SUCCESS.
+// `file` and `line` are supplied by CUDA_CHECK but are not part of the message.
+static inline void gpuAssert(CUresult code, const char *file, int line)
+{
+  (void)file;
+  (void)line;
+  if (code == CUDA_SUCCESS)
+    return;
+  // The lookup itself can fail (unrecognized code); never hand an unset or
+  // NULL string to the formatting call below.
+  const char *str = NULL;
+  if (cuGetErrorString(code, &str) != CUDA_SUCCESS || str == NULL)
+    str = "unknown error";
+  // snprintf truncates instead of writing past the fixed-size buffer.
+  char err[1024];
+  snprintf(err, sizeof(err), "Triton Error [CUDA]: %s", str);
+  setRuntimeError(err);
+}
+
+// Statement-like wrapper (do/while(0)) so that `CUDA_CHECK(x);` is safe in
+// un-braced if/else bodies.
+#define CUDA_CHECK(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (0)
+
+typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);
+
+// Looks up cuLaunchKernelEx in libcuda.so.1 at run time.
+// Returns the function pointer, or NULL with a Python RuntimeError set.
+static cuLaunchKernelEx_t getLaunchKernelExHandle(void) {
+  // Open the shared library
+  void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
+  if (!handle) {
+    setRuntimeError("Failed to open libcuda.so.1");
+    return NULL;
+  }
+  // Clear any existing error
+  dlerror();
+  cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx");
+  // Check for errors; a NULL symbol is not callable either
+  const char *dlsym_error = dlerror();
+  if (dlsym_error || !cuLaunchKernelExHandle) {
+    setRuntimeError("Failed to retrieve cuLaunchKernelEx from libcuda.so.1");
+    // Drop the reference taken by dlopen() above; nothing points into it.
+    dlclose(handle);
+    return NULL;
+  }
+  // `handle` stays open on purpose: the returned pointer refers into it.
+  return cuLaunchKernelExHandle;
+}
+
+// define a macro from driver.py to introduce extra args
+// e.g.:
+// arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}"
for i, ty in signature.items()) +// cuda_launcher_src += f"#define ARG_DECLS {arg_decls}".format(arg_decls=arg_decls) +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function EXTRA_INNER_LAUNCH_PARAM_DECLS) { + // define a macro from driver.py to introduce extra params + // e.g.: cuda_launcher_src += f"#define PARAMS {params}".format(params=', '.join(f"&arg{i}" for i in params)) + void *params[] = { + INNER_LAUNCH_CUDA_CHECK_ARGS + }; + if (gridX*gridY*gridZ > 0) { + if (num_ctas == 1) { + CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); + } else { + CUlaunchAttribute launchAttr[2]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = clusterDimX; + launchAttr[0].value.clusterDim.y = clusterDimY; + launchAttr[0].value.clusterDim.z = clusterDimZ; + launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; + CUlaunchConfig config; + config.gridDimX = gridX * clusterDimX; + config.gridDimY = gridY * clusterDimY; + config.gridDimZ = gridZ * clusterDimZ; + config.blockDimX = 32 * num_warps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared_memory; + config.hStream = stream; + config.attrs = launchAttr; + config.numAttrs = 2; + static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; + if (cuLaunchKernelExHandle == NULL) { + cuLaunchKernelExHandle = getLaunchKernelExHandle(); + } + CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); + } + } +} + +typedef struct _DevicePtrInfo { + CUdeviceptr dev_ptr; + bool valid; +} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) { + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if 
(PyLong_Check(obj)) { + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); + return ptr_info; + } + if (obj == Py_None) { + // valid nullptr + return ptr_info; + } + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) { + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + } + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) + return ptr_info; + uint64_t dev_ptr; + int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == CUDA_ERROR_INVALID_VALUE) { + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + } + ptr_info.dev_ptr = dev_ptr; + Py_DECREF(ret); // Thanks ChatGPT! 
+ return ptr_info; + } + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +} + +static inline CUtensorMap* getTmaDesc(PyObject *obj) { + if (sizeof(CUtensorMap*) != 8) { + PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation"); + return NULL; + } + + PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr"); + if (!method_handle) { + PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist"); + return NULL; + } + + PyObject *empty_tuple = PyTuple_New(0); + if (!empty_tuple) { + Py_DECREF(method_handle); + PyErr_SetString(PyExc_SystemError, "Internal Python error!"); + return NULL; + } + PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(method_handle); + if (!method_ret) { + PyErr_SetString(PyExc_SystemError, "Internal Python error!"); + return NULL; + } + + if (!PyLong_Check(method_ret)) { + PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int"); + Py_DECREF(method_ret); + return NULL; + } + + uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret); + Py_DECREF(method_ret); + if (!ptr_as_uint) { + PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()"); + return NULL; + } + if (ptr_as_uint % 64 != 0) { + PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned"); + return NULL; + } + + return (CUtensorMap*)(ptr_as_uint); +} + +static PyObject* launch(PyObject* self, PyObject* args) { + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + // example python code to generate the arg list: + // ' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()]) + LAUNCH_PY_ARGS; + if(!PyArg_ParseTuple(args, + 
PY_ARG_FORMAT_STR + , &gridX, &gridY, &gridZ, &_stream, &_function, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook + EXTRA_LAUNCH_PARSE_PY_ARGS + )) { + return NULL; + } + + int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + if (!PyArg_ParseTuple(kernel_metadata, "iiiiii", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) { + PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + return NULL; + } + + // extract launch metadata + if (launch_enter_hook != Py_None){ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + } + + // raise exception asap + // python string: "".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()]) + DEVICE_PTR_INFO_VARS; + // python string:"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()]) + TMA_DESC_VARS; + + Py_BEGIN_ALLOW_THREADS; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function + EXTRA_INNER_LAUNCH_CALL_ARGS + ); + Py_END_ALLOW_THREADS; + if (PyErr_Occurred()) { + return NULL; + } + if(launch_exit_hook != Py_None){ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + + } + + // return None + Py_INCREF(Py_None); + return Py_None; +} + +static PyMethodDef ModuleMethods[] = { + {"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = { + PyModuleDef_HEAD_INIT, + "__triton_launcher", + NULL, 
//documentation + -1, //size + ModuleMethods +}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + return m; +} diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/cuda_util.c similarity index 100% rename from third_party/nvidia/backend/driver.c rename to third_party/nvidia/backend/cuda_util.c diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index bf1f066d5537..0b13ebe542e7 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -77,7 +77,7 @@ def __new__(cls): return cls.instance def __init__(self): - mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils") + mod = compile_module_from_src(Path(os.path.join(dirname, "cuda_util.c")).read_text(), "cuda_utils") self.load_binary = mod.load_binary self.get_device_properties = mod.get_device_properties self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters @@ -145,7 +145,7 @@ def format_of(ty): args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiKKOOOO" + args_format - args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + args_list = ', '.join(f"&_arg{i}" for i, ty in signature.items()) internal_args_list = [] for i, ty in signature.items(): @@ -159,254 +159,39 @@ def format_of(ty): # generate glue code params = [i for i in signature.keys() if i not in constants] - src = f""" -#include \"cuda.h\" -#include -#include -#include - -static inline void gpuAssert(CUresult code, const char *file, int line) -{{ - if (code != CUDA_SUCCESS) - {{ - const char* prefix = "Triton Error [CUDA]: "; - const char* str; - cuGetErrorString(code, &str); - char err[1024] = {{0}}; - strcat(err, prefix); - strcat(err, str); - PyGILState_STATE gil_state; - gil_state = 
PyGILState_Ensure(); - PyErr_SetString(PyExc_RuntimeError, err); - PyGILState_Release(gil_state); - }} -}} - -#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - -typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra); - -static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ - // Open the shared library - void* handle = dlopen("libcuda.so.1", RTLD_LAZY); - if (!handle) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); - return NULL; - }} - // Clear any existing error - dlerror(); - cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx"); - // Check for errors - const char *dlsym_error = dlerror(); - if (dlsym_error) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1"); - return NULL; - }} - return cuLaunchKernelExHandle; -}} - -static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; - if (gridX*gridY*gridZ > 0) {{ - if (num_ctas == 1) {{ - CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); - }} else {{ - CUlaunchAttribute launchAttr[2]; - launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launchAttr[0].value.clusterDim.x = clusterDimX; - launchAttr[0].value.clusterDim.y = clusterDimY; - launchAttr[0].value.clusterDim.z = clusterDimZ; - launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; - launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; - CUlaunchConfig config; - config.gridDimX = gridX * clusterDimX; - config.gridDimY = gridY * clusterDimY; - config.gridDimZ = gridZ * clusterDimZ; - 
config.blockDimX = 32 * num_warps; - config.blockDimY = 1; - config.blockDimZ = 1; - config.sharedMemBytes = shared_memory; - config.hStream = stream; - config.attrs = launchAttr; - config.numAttrs = 2; - static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; - if (cuLaunchKernelExHandle == NULL) {{ - cuLaunchKernelExHandle = getLaunchKernelExHandle(); - }} - CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); - }} - }} -}} - -typedef struct _DevicePtrInfo {{ - CUdeviceptr dev_ptr; - bool valid; -}} DevicePtrInfo; - -static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); - if(!ptr_info.dev_ptr) - return ptr_info; - uint64_t dev_ptr; - int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); - if (status == CUDA_ERROR_INVALID_VALUE) {{ - PyErr_Format(PyExc_ValueError, - "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); - ptr_info.valid = false; - }} - ptr_info.dev_ptr = dev_ptr; - Py_DECREF(ret); // Thanks ChatGPT! 
- return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - ptr_info.valid = false; - return ptr_info; -}} - -static inline CUtensorMap* getTmaDesc(PyObject *obj) {{ - if (sizeof(CUtensorMap*) != 8) {{ - PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation"); - return NULL; - }} - - PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr"); - if (!method_handle) {{ - PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist"); - return NULL; - }} - - PyObject *empty_tuple = PyTuple_New(0); - if (!empty_tuple) {{ - Py_DECREF(method_handle); - PyErr_SetString(PyExc_SystemError, "Internal Python error!"); - return NULL; - }} - PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(method_handle); - if (!method_ret) {{ - PyErr_SetString(PyExc_SystemError, "Internal Python error!"); - return NULL; - }} - - if (!PyLong_Check(method_ret)) {{ - PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int"); - Py_DECREF(method_ret); - return NULL; - }} - - uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret); - Py_DECREF(method_ret); - if (!ptr_as_uint) {{ - PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()"); - return NULL; - }} - if (ptr_as_uint % 64 != 0) {{ - PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned"); - return NULL; - }} - - return (CUtensorMap*)(ptr_as_uint); -}} - -static PyObject* launch(PyObject* self, PyObject* args) {{ - int gridX, gridY, gridZ; - uint64_t _stream; - uint64_t _function; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *kernel_metadata = NULL; - PyObject *launch_metadata = NULL; - {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function, - 
&kernel_metadata, &launch_metadata, - &launch_enter_hook, &launch_exit_hook {args_list})) {{ - return NULL; - }} - - int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; - if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ - PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); - return NULL; - }} - - // extract launch metadata - if (launch_enter_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_enter_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - // raise exception asap - {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])}; - Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); - Py_END_ALLOW_THREADS; - if (PyErr_Occurred()) {{ - return NULL; - }} - - if(launch_exit_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_exit_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - - }} - - // return None - Py_INCREF(Py_None); - return Py_None; -}} - -static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel -}}; - -static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods -}}; - -PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - 
PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; -}} -""" + + def gen_c_def_macro(macro_name, macro_value): + return f"#define {macro_name} {macro_value}\n" + + # macros to define: + """ + #define EXTRA_INNER_LAUNCH_PARAM_DECLS + #define INNER_LAUNCH_CUDA_CHECK_ARGS + #define LAUNCH_PY_ARGS + #define PY_ARG_FORMAT_STR + #define EXTRA_LAUNCH_PARSE_PY_ARGS + #define DEVICE_PTR_INFO_VARS + #define TMA_DESC_VARS + #define EXTRA_INNER_LAUNCH_CALL_ARGS + """ + macro_defs = gen_c_def_macro("EXTRA_INNER_LAUNCH_PARAM_DECLS", ", " + arg_decls if arg_decls else "") + macro_defs += gen_c_def_macro("INNER_LAUNCH_CUDA_CHECK_ARGS", ', '.join(f"&arg{i}" for i in params)) + macro_defs += gen_c_def_macro("LAUNCH_PY_ARGS", ';'.join([f"{_extracted_type(ty)} _arg{i}" for i, ty in signature.items()])) + macro_defs += gen_c_def_macro("PY_ARG_FORMAT_STR", f'"{format}"') + macro_defs += gen_c_def_macro("EXTRA_LAUNCH_PARSE_PY_ARGS", ", " + args_list if args_list else "") + device_ptr_info_var_list = [] + tma_desc_var_list = [] + for i, ty in signature.items(): + if ty[0] == "*": + device_ptr_info_var_list.append(f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;") + elif ty == "nvTmaDesc": + tma_desc_var_list.append(f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;") + + macro_defs += gen_c_def_macro("DEVICE_PTR_INFO_VARS", " \\\n".join(device_ptr_info_var_list)) + macro_defs += gen_c_def_macro("TMA_DESC_VARS", " \\\n".join(tma_desc_var_list)) + extra_inner_launch_call_args = ', '.join(internal_args_list) + macro_defs += gen_c_def_macro("EXTRA_INNER_LAUNCH_CALL_ARGS", ', ' + extra_inner_launch_call_args if extra_inner_launch_call_args else "") + src = macro_defs + Path(os.path.join(dirname, "cuda_launcher.c")).read_text() return src