Skip to content

Commit

Permalink
Moved Pallas GPU lowering registartion code into a separate submodule
Browse files Browse the repository at this point in the history
This makes the layout similar to the one we use in Pallas TPU.

PiperOrigin-RevId: 609324956
  • Loading branch information
superbobry authored and The jax_triton Authors committed Feb 26, 2024
1 parent 708d3e8 commit 8c9c302
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def compile_ttir_to_ptx_inplace(
cuda_options: cb.CUDAOptions,
device: int = 0,
device_type: str = "cuda",
) -> Tuple[str, str, int, int]:
) -> Tuple[str, str, int, int, Any]:
compute_capability = triton_kernel_call_lib.get_compute_capability(device)
if cuda_options.debug:
print(ttir)
Expand Down

0 comments on commit 8c9c302

Please sign in to comment.