Integration Fixes on top of Triton import
PiperOrigin-RevId: 588045313
Moerafaat authored and The jax_triton Authors committed Dec 7, 2023
1 parent 68c6262 commit 71fc158
Showing 2 changed files with 114 additions and 65 deletions.
173 changes: 111 additions & 62 deletions jax_triton/triton_lib.py
@@ -13,6 +13,7 @@
# limitations under the License.

"""Module for calling Triton kernels from JAX."""

# b/301982023
from __future__ import annotations

@@ -45,15 +46,21 @@
import triton.language as tl
from triton.runtime import autotuner
import triton._C.libtriton.triton as _triton
from triton.common.backend import get_backend
import triton.compiler.backends.cuda as cb

CAN_USE_TRITON = True
except ModuleNotFoundError:
pass
try:
from jax._src.lib import gpu_triton as triton_kernel_call_lib
except ImportError:
raise ValueError(
"Cannot import jaxlib triton library. You may need a newer version of jaxlib. Try installing a nightly wheel from: https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html or https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html"
)
"Cannot import jaxlib triton library. You may need a newer"
" version of jaxlib. Try installing a nightly wheel from:"
" https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html"
" or https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html"
)

os.environ["TRITON_CACHE_DIR"] = ""
map, unsafe_map = util.safe_map, map
@@ -96,17 +103,18 @@ def normalize_grid(grid: GridOrLambda, metaparams) -> Tuple[int, int, int]:
def avals_to_layouts(avals):
return [list(reversed(range(aval.ndim))) for aval in avals]


def get_triton_type(obj: Any) -> str:
if isinstance(obj, (jax.core.ShapedArray, state.AbstractRef)):
return f"*{_JAX_TO_TRITON_TYPE_MAP[obj.dtype]}"
if isinstance(obj, tl.constexpr):
obj = obj.value
if isinstance(obj, int):
if -2**31 <= obj < 2**31:
if -(2**31) <= obj < 2**31:
return "i32"
elif 2**31 <= obj < 2**32:
return "u32"
elif -2**63 <= obj < 2**63:
elif -(2**63) <= obj < 2**63:
return "i64"
elif 2**63 <= obj < 2**64:
return "u64"
@@ -131,7 +139,8 @@ def get_triton_type(obj: Any) -> str:
triton_kernel_call_p = jax.core.Primitive("triton_kernel_call")
triton_kernel_call_p.multiple_results = True
triton_kernel_call_p.def_impl(
functools.partial(xla.apply_primitive, triton_kernel_call_p))
functools.partial(xla.apply_primitive, triton_kernel_call_p)
)


@triton_kernel_call_p.def_abstract_eval
@@ -147,65 +156,82 @@ def aval_size_bytes(aval):


def ptx_get_kernel_name(module) -> str:
return tc.get_kernel_name(module, pattern='// .globl')
return cb.get_kernel_name(module, pattern="// .globl")


def get_arch_default_num_warps(device_type):
if device_type in ["cuda", "hip"]:
num_warps = 4
else:
device_backend = get_backend(device_type)
assert device_backend
arch = device_backend.get_architecture_descriptor()
num_warps = arch["num_warps"]
return num_warps


def get_arch_default_num_stages(device_type, capability):
if device_type == "cuda":
num_stages = 3 if capability >= 75 else 2
else:
device_backend = get_backend(device_type)
assert device_backend
arch = device_backend.get_architecture_descriptor()
num_stages = arch["num_stages"]

return num_stages


def compile_ttir_to_ptx_inplace(
ttir,
cuda_backend: cb.CUDABackend,
cuda_options: cb.CUDAOptions,
device: int = 0,
device_type: str = "cuda",
num_warps: Optional[int] = None,
num_stages: Optional[int] = None,
num_ctas: int = 1,
enable_fp_fusion: bool = True,
enable_warp_specialization: bool = False,
enable_persistent: bool = False,
dump: bool = False,
) -> Tuple[str, str, int, int]:
compute_capability = triton_kernel_call_lib.get_compute_capability(device)
if num_warps is None:
num_warps = tc.get_arch_default_num_warps(device_type)
if num_stages is None:
num_stages = tc.get_arch_default_num_stages(
if cuda_options.num_warps is None:
cuda_options.num_warps = get_arch_default_num_warps(device_type)
if cuda_options.num_stages is None:
cuda_options.num_stages = get_arch_default_num_stages(
device_type, capability=compute_capability
)
if dump:
if cuda_options.debug:
print(ttir)
try:
target = tc.CudaTargetDescriptor(
capability=compute_capability,
num_warps=num_warps,
enable_fp_fusion=enable_fp_fusion,
)
ttir = tc.optimize_ttir(ttir, target)
ttgir = tc.ttir_to_ttgir(ttir, num_warps, num_ctas=num_ctas, target=target)
ttgir = tc.optimize_ttgir(
ttgir,
num_stages,
num_warps,
num_ctas=num_ctas,
target=target,
cluster_info=_triton.ClusterInfo(),
enable_warp_specialization=enable_warp_specialization,
enable_persistent=enable_persistent,
optimize_epilogue=False,
metadata = dict()
opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options)
ttgir = cuda_backend.make_ttgir(
opt_ttir,
metadata,
cuda_options,
compute_capability,
)
except RuntimeError as e:
ttir.dump()
raise ValueError("TTIR->TTGIR pass failed!") from e
if dump:
if cuda_options.debug:
print(ttgir)
extern_libs = {}
try:
llir = tc.ttgir_to_llir(ttgir, extern_libs, target, _triton.TMAInfos())
llir = cuda_backend.make_llir(
ttgir,
metadata,
cuda_backend.parse_linker_options(dict()),
compute_capability,
)
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = _triton.get_shared_memory_size(ttgir)
if dump:
if cuda_options.debug:
print(llir)
ptx = tc.llir_to_ptx(llir, target)
if dump:
ptx = cuda_backend.make_ptx(
llir,
metadata,
cuda_options,
compute_capability,
)
if cuda_options.debug:
print(ptx)
name = ptx_get_kernel_name(ptx)
return ptx, name, shared_mem_bytes, compute_capability
@@ -230,29 +256,29 @@ def get_or_create_triton_kernel(
) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]:
device_type = "cuda"
if num_warps is None:
num_warps = tc.get_arch_default_num_warps(device_type)
num_warps = get_arch_default_num_warps(device_type)

signature = dict(enumerate(arg_dtypes))
# TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
# We assume that all arrays are aligned to 16 bytes, and Triton may use this
# assumption, unless array args are included in the `do_not_specialize` list.
# We replace array arguments with mock Torch tensors, to allow us to use
# `JITFunction._get_config` to get the specialization.
# `JITFunction._get_config` to get the specialization_attr.
mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
args_for_specialization = [mock_torch_tensor] * len(arg_dtypes)
args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
for i, _, v in scalar_args:
args_for_specialization[i] = v
specialization = fn._get_config(*args_for_specialization) # pylint: disable=protected-access
args_for_specialization_attr[i] = v
specialization_attr = fn._get_config(*args_for_specialization_attr) # pylint: disable=protected-access

constants = {fn.arg_names.index(k): v for k, v in metaparams.items()}
constants.update({i: None for i, _, v in scalar_args if v is None})
constants.update({i: 1 for i in specialization.equal_to_1})
constants.update({i: 1 for i in specialization_attr.equal_to_1})

# Cache key should contain any parameter that can affect the compiler output.
cache_key = (
fn,
tuple(signature.items()),
specialization,
tuple(vars(specialization_attr).values()),
tuple(constants.items()),
num_warps,
num_stages,
@@ -269,22 +295,43 @@
# general.
device = 0
arch = triton_kernel_call_lib.get_compute_capability(device)

target = ("cuda", arch)
cuda_backend = cb.CUDABackend(target)

cuda_options = cuda_backend.parse_compiler_options(
dict(
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
cluster_dims=(1, 1, 1),
enable_warp_specialization=enable_warp_specialization,
enable_persistent=enable_persistent,
optimize_epilogue=False,
debug=dump,
enable_fp_fusion=enable_fp_fusion,
)
)

module = code_gen.ast_to_ttir(
fn, signature, specialization, constants, debug=dump, target=arch
fn,
specialization=tc.ASTSource(
fn,
constants=constants,
signature=signature,
attrs=specialization_attr,
),
options=cuda_options,
)

ttir = str(module) # `module` is compiled in-place, so copy TTIR here.
ptx, kernel_name, shared_mem_bytes, compute_capability = (
compile_ttir_to_ptx_inplace(
module,
cuda_backend,
cuda_options,
device=device,
device_type=device_type,
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
enable_fp_fusion=enable_fp_fusion,
enable_warp_specialization=enable_warp_specialization,
enable_persistent=enable_persistent,
dump=dump,
)
)
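
The two hunks above swap the old tc.* compile helpers for Triton's CUDABackend passes: the kernel is lowered from the Python AST to TTIR via an ASTSource, then make_ttir/make_ttgir/make_llir/make_ptx run with a shared metadata dict and the parsed CUDAOptions. A condensed, illustrative sketch of that path follows; it mirrors the calls in the diff, but the `tc`/`code_gen` import paths are assumptions about the Triton revision this change targets, and caching, error handling, and shared-memory bookkeeping are omitted.

import triton.compiler.backends.cuda as cb                # CUDABackend, CUDAOptions
from triton.compiler import compiler as tc                 # ASTSource (assumed path)
from triton.compiler import code_generator as code_gen     # ast_to_ttir (assumed path)
from jax._src.lib import gpu_triton as triton_kernel_call_lib


def lower_to_ptx_sketch(fn, signature, specialization_attr, constants,
                        num_warps, num_stages, device=0):
  # `specialization_attr` comes from fn._get_config(...), as in the hunk above.
  arch = triton_kernel_call_lib.get_compute_capability(device)
  cuda_backend = cb.CUDABackend(("cuda", arch))
  cuda_options = cuda_backend.parse_compiler_options(
      dict(num_warps=num_warps, num_stages=num_stages, num_ctas=1,
           cluster_dims=(1, 1, 1), enable_warp_specialization=False,
           enable_persistent=False, optimize_epilogue=False,
           debug=False, enable_fp_fusion=True))

  # Front end: Python AST -> TTIR, driven by an ASTSource.
  module = code_gen.ast_to_ttir(
      fn,
      specialization=tc.ASTSource(fn, constants=constants, signature=signature,
                                  attrs=specialization_attr),
      options=cuda_options)

  # Backend passes; each stage threads the shared `metadata` dict.
  metadata = dict()
  ttir = cuda_backend.make_ttir(module, metadata, cuda_options)
  ttgir = cuda_backend.make_ttgir(ttir, metadata, cuda_options, arch)
  llir = cuda_backend.make_llir(
      ttgir, metadata, cuda_backend.parse_linker_options(dict()), arch)
  ptx = cuda_backend.make_ptx(llir, metadata, cuda_options, arch)
  return ptx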

@@ -294,7 +341,7 @@ def get_or_create_triton_kernel(

_COMPILED_KERNEL_CACHE[cache_key] = kernel

return kernel, specialization
return kernel, specialization_attr


def triton_kernel_call_lowering(
@@ -320,7 +367,8 @@
):
if jaxlib.version.__version_info__ < (0, 3, 22) and input_output_aliases:
raise NotImplementedError(
"`input_output_aliases` only supported on `jaxlib>=0.3.22")
"`input_output_aliases` only supported on `jaxlib>=0.3.22"
)

kernel_call_name = name
args = list(ctx.avals_in)
@@ -419,7 +467,7 @@ def prune_configs(configs, named_args):

kernel_calls = []
for params in config_params:
kernel, specialization = get_or_create_triton_kernel(
kernel, specialization_attr = get_or_create_triton_kernel(
fn,
arg_dtypes,
scalar_args,
@@ -440,10 +488,10 @@ def prune_configs(configs, named_args):
kernel_params.append(
triton_kernel_call_lib.create_array_parameter(
zeroed_params_with_sizes.get(i, 0),
16 if (i in specialization.divisible_by_16) else 0,
16 if (i in specialization_attr.divisible_by_16) else 0,
)
)
elif i not in specialization.equal_to_1:
elif i not in specialization_attr.equal_to_1:
kernel_params.append(
triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
)
@@ -615,7 +663,8 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"`triton_call` is only available when `triton` is installed."
)
out_shape = tree_util.tree_map(
lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), out_shape)
lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), out_shape
)
flat_args, _ = tree_util.tree_flatten(args)
# TODO(sharadmv): check in_tree is flat (no Pytrees allowed in triton_call)
flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
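
The hunk header above points at `add`, the module's doc-string example for `triton_call`. For reference, a minimal sketch of that usage pattern, assuming the public jax_triton API; the kernel body, names, and block size here are illustrative rather than a copy of the doc string.

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, block_size: tl.constexpr):
  # Elementwise add over one block of `block_size` elements.
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  mask = offsets < 8  # illustrative: the inputs below hold 8 elements
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  tl.store(output_ptr + offsets, x + y, mask=mask)


def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  return jt.triton_call(
      x, y,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // block_size,),
      block_size=block_size)  # metaparams are forwarded as tl.constexpr args


x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]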
6 changes: 3 additions & 3 deletions tests/triton_call_test.py
@@ -539,15 +539,15 @@ def test_specialization(self):
pass # Error thrown as the mocked method's return value is invalid.

mock_compile.assert_called_once()
specialization = mock_compile.call_args.args[2]
specialization = mock_compile.call_args[1]['specialization']

# Pointers are assumed to divide by 16, as do `M`, `N`, `stride_{bk,cm}`.
# However, we've marked `a_ptr`, `M`, `stride_bk`, and `c_ptr` as "do not
# specialize", leaving `b_ptr`, `N`, and `stride_cm`.
self.assertEqual(specialization.divisible_by_16, (1, 3, 9))
self.assertEqual(specialization.attrs.divisible_by_16, (1, 3, 9))
# `stride_{ak,bn,cn}` equal 1, but we've marked `stride_ak` as "do not
# specialize" leaving `stride_{bn,cn}`.
self.assertEqual(specialization.equal_to_1, (8, 10))
self.assertEqual(specialization.attrs.equal_to_1, (8, 10))


if __name__ == "__main__":
