Refactor the C code template in third_party/nvidia/backend/driver.py
#4722
Conversation
LGTM
def gen_c_def_macro(macro_name, macro_value):
    return f"#define {macro_name} {macro_value}\n"

# macros to define:
"""
#define EXTRA_INNER_LAUNCH_PARAM_DECLS
#define INNER_LAUNCH_CUDA_CHECK_ARGS
#define LAUNCH_PY_ARGS
#define PY_ARG_FORMAT_STR
#define EXTRA_LAUNCH_PARSE_PY_ARGS
#define DEVICE_PTR_INFO_VARS
#define TMA_DESC_VARS
#define EXTRA_INNER_LAUNCH_CALL_ARGS
"""
macro_defs = gen_c_def_macro("EXTRA_INNER_LAUNCH_PARAM_DECLS", ", " + arg_decls if arg_decls else "")
macro_defs += gen_c_def_macro("INNER_LAUNCH_CUDA_CHECK_ARGS", ', '.join(f"&arg{i}" for i in params))
macro_defs += gen_c_def_macro("LAUNCH_PY_ARGS",
                              ';'.join([f"{_extracted_type(ty)} _arg{i}" for i, ty in signature.items()]))
macro_defs += gen_c_def_macro("PY_ARG_FORMAT_STR", f'"{format}"')
macro_defs += gen_c_def_macro("EXTRA_LAUNCH_PARSE_PY_ARGS", ", " + args_list if args_list else "")
device_ptr_info_var_list = []
tma_desc_var_list = []
for i, ty in signature.items():
    if ty[0] == "*":
        device_ptr_info_var_list.append(
            f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;")
    elif ty == "nvTmaDesc":
        tma_desc_var_list.append(f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;")

macro_defs += gen_c_def_macro("DEVICE_PTR_INFO_VARS", " \\\n".join(device_ptr_info_var_list))
macro_defs += gen_c_def_macro("TMA_DESC_VARS", " \\\n".join(tma_desc_var_list))
extra_inner_launch_call_args = ', '.join(internal_args_list)
macro_defs += gen_c_def_macro("EXTRA_INNER_LAUNCH_CALL_ARGS",
                              ', ' + extra_inner_launch_call_args if extra_inner_launch_call_args else "")
src = macro_defs + Path(os.path.join(dirname, "cuda_launcher.c")).read_text()
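For orientation: the macro prelude built above is prepended to the fixed cuda_launcher.c before compilation. The sketch below is illustrative only, not the actual contents of cuda_launcher.c in this PR. It hard-codes a stand-in prelude for a hypothetical kernel taking one int32 and one device pointer, and shows how the template side could consume two of the macros.

```c
/* Illustrative only -- not the actual cuda_launcher.c from this PR. */
#include <cuda.h>
#include <stdint.h>

/* Stand-in for the prelude driver.py would generate, here for a hypothetical
 * kernel taking one int32 and one device pointer argument: */
#define EXTRA_INNER_LAUNCH_PARAM_DECLS , int32_t arg0, CUdeviceptr arg1
#define INNER_LAUNCH_CUDA_CHECK_ARGS &arg0, &arg1

/* Fixed template code: the per-signature pieces are spliced in by the macros. */
static CUresult inner_launch(CUfunction function, CUstream stream, int gridX,
                             int gridY, int gridZ, int num_warps,
                             int shared_memory EXTRA_INNER_LAUNCH_PARAM_DECLS) {
  void *params[] = {INNER_LAUNCH_CUDA_CHECK_ARGS};
  return cuLaunchKernel(function, gridX, gridY, gridZ, 32 * num_warps, 1, 1,
                        shared_memory, stream, params, NULL);
}
```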
Personally I find this harder to read/modify than the previous version. Can we move only the functions that don't need patching into a separate file and keep the rest in a template kind of way (and maybe break up functions if it helps move more code out of Python)?
Fly-by comment here, since this seems quite related to what we did with the launcher internally at Google: I don't think we need Python string interpolation at all; we could completely precompile a generic version of the launcher. @ThomasRaoux, would this be something interesting for you? If so, I could work on that (or if @sfzhu93 wants to take that on, I'm super happy to have that support!).

I wanted to try upstreaming this, but didn't get to it yet. The best I have right now is this patch I can share (I don't have a machine with command-line git at the moment to make a proper PR out of it 😅): https://gist.github.com/gflegar/6fea5e50023a69e17c64e73eb0037cab. It should hopefully
(Sorry for not making this more polished; lots of stuff is going on right now. I can definitely polish it more in a few weeks' time, but I just wanted to bring this option up since things seem to be moving on this.)
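For readers who want a concrete picture of the precompiled-launcher idea described in the comment above, here is a rough sketch. It is not the linked gist and not Google's internal implementation; all names, types, and structure are assumptions. The point is that a single compiled launcher can pack type-tagged argument slots into the void*[] that cuLaunchKernel expects, so no per-signature C code has to be generated or compiled at runtime.

```c
/* Rough sketch of a precompiled, signature-generic launcher (hypothetical). */
#include <cuda.h>
#include <stdint.h>

/* Type tags would drive the Python-object -> ArgSlot conversion (not shown);
 * that conversion table can be built once per kernel at compile time. */
typedef enum { ARG_I32, ARG_I64, ARG_F32, ARG_DEVPTR } ArgTag;

typedef union {
  int32_t i32;
  int64_t i64;
  float f32;
  CUdeviceptr devptr;
} ArgSlot;

/* Launch a kernel whose arguments were already converted into ArgSlot values.
 * The hot path does no string formatting and no heap allocation. */
static CUresult launch_generic(CUfunction fn, CUstream stream,
                               unsigned gridX, unsigned gridY, unsigned gridZ,
                               unsigned blockX, unsigned sharedBytes,
                               int nargs, ArgSlot *slots) {
  void *params[64]; /* assumed upper bound on kernel arity for this sketch */
  for (int i = 0; i < nargs; ++i) {
    /* Each kernel parameter is passed by address; cuLaunchKernel reads the
     * right number of bytes based on the kernel's own signature. */
    params[i] = &slots[i];
  }
  return cuLaunchKernel(fn, gridX, gridY, gridZ, blockX, 1, 1, sharedBytes,
                        stream, params, NULL);
}
```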
Interesting, did you measure the dispatch cost of doing that?
@gflegar @ThomasRaoux are you happy to temporarily move forward with my approach? I will remove the macros and wrap the code that will be generated by Python into separate functions.
Closing this for now; if we have a strong motivation, we could discuss going with the pre-compiled launcher.
Oh, I completely missed this, since I was out for a while and got a huge email backlog.

I didn't measure it yet, since I can't run the upstream version in our internal environment (we don't have a C compiler available at runtime, hence we patched Triton to precompile it), and I can't run our patched version in OSS since I don't have CMake files for it yet. So getting an apples-to-apples comparison is somewhat difficult at this point. I will definitely run those measurements if I try to upstream it. But I don't think it should be that bad, since I was specifically trying to implement it in a way that minimizes launch overhead: there are no memory allocations, and hopefully only inexpensive Python API calls are needed to launch the kernel. The expensive part of figuring out and caching the signature is done at kernel compile time.

Do you have a specific set of utilities / methodology you use to evaluate the dispatch cost that you could point me to, so I use the same thing as you do when I attempt to evaluate it, @ThomasRaoux?
Refactor the C code template in driver.py

Previously, third_party/nvidia/backend/driver.py contained a long format string that defined the C source code of the launcher. In this commit, I moved that long format string from the Python code into a separate C file and use macros on the driver.py side to fill in the missing parts of the C code. This improves readability and makes future extensions to driver.py easier. (A hypothetical example of the generated macro prelude is sketched after the checklist below.)

- I am not making a trivial change, such as fixing a typo in a comment.
- I have written a PR description following these rules.
- I have run pre-commit run --from-ref origin/main --to-ref HEAD.
- This PR does not need a test because it doesn't introduce any new features or bug fixes.
- I have not added any lit tests.
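To make the description concrete, here is a hypothetical example of the prelude that the gen_c_def_macro calls in the diff would emit for a kernel whose signature is roughly (int32 _arg0, float* _arg1, nvTmaDesc _arg2). The expansions are derived from the Python snippet above; the real output also depends on arg_decls, params, and internal_args_list, which are not shown in the hunk, so the exact text in the PR may differ.

```c
/* Hypothetical generated prelude, prepended to cuda_launcher.c by driver.py.
 * Expansions derived from the gen_c_def_macro calls in the diff; the real
 * output depends on values not visible in the hunk. */
#define INNER_LAUNCH_CUDA_CHECK_ARGS &arg0, &arg1, &arg2
#define DEVICE_PTR_INFO_VARS DevicePtrInfo ptr_info1 = getPointer(_arg1, 1); if (!ptr_info1.valid) return NULL;
#define TMA_DESC_VARS CUtensorMap* tma_ptr2 = getTmaDesc(_arg2); if (!tma_ptr2) return NULL;
```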