JAX-Triton is a repository containing integrations between JAX and Triton.
JAX is a Python library for accelerated numerical computing and Triton is a Python library and compiler for writing custom GPU kernels. When we put the two together, we get JAX-Triton, which enables writing custom GPU kernels using Triton that can be embedded inside JAX programs.
You can install JAX-Triton with pip; this will also install a compatible JAX and Triton.
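For example (assuming the package is published on PyPI under the name `jax-triton`):

```bash
# Assumed PyPI package name
$ pip install jax-triton
```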
JAX-Triton only works with JAX on GPU, so you'll need to make sure you have a CUDA-compatible jaxlib installed. For example, you could run:

```bash
$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
JAX-Triton and Pallas are developed at JAX and jaxlib HEAD and close to Triton HEAD. To get a bleeding-edge installation of JAX-Triton, install it directly from the GitHub repository.
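A likely command (the exact repository URL is an assumption):

```bash
# Assumed repository location; adjust the URL to wherever jax-triton is hosted.
$ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton'
```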
This should install compatible versions of JAX and Triton. JAX-Triton does depend on jaxlib, but jaxlib is usually a more stable dependency; you might be able to get away with using a recent jaxlib release:
```bash
$ pip install jaxlib[cuda]
$ # or
$ pip install jaxlib[cuda11_pip]
$ # or
$ pip install jaxlib[cuda12_pip]
```
If you find there are issues with the latest jaxlib release, you can try using a jaxlib nightly. To install a new jaxlib, find a link to a CUDA 11 nightly or CUDA 12 nightly, then install it via pip.
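A likely form (`<link to nightly>` is a placeholder for the wheel URL, and the plain `jaxlib @` requirement is an assumption):

```bash
# Assumed form: install the nightly jaxlib wheel directly from its URL.
$ pip install 'jaxlib @ <link to nightly>'
```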
or to install CUDA via pip automatically, you can do:

```bash
$ pip install 'jaxlib[cuda11_pip] @ <link to nightly>'
$ # or
$ pip install 'jaxlib[cuda12_pip] @ <link to nightly>'
```
The main function of interest is `jax_triton.triton_call` for applying Triton functions to JAX arrays, including inside `jax.jit`-compiled functions. For example, we can define a kernel from the Triton tutorial:
```python
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    block_size: tl.constexpr,
):
  """Adds two vectors."""
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  mask = offsets < 8
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)
```
Then we can apply it to JAX arrays using `jax_triton.triton_call`:
```python
import jax
import jax.numpy as jnp
import jax_triton as jt

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  return jt.triton_call(
      x,
      y,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // block_size,),
      block_size=block_size)

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))
```
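Both calls print the elementwise sum, `[ 8 10 12 14 16 18 20 22]`.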
See the examples directory, especially fused_attention.py and the fused attention ipynb.
The primary way of using JAX-Triton is using `jax_triton.triton_call` to call handwritten Triton kernels from inside JIT-ted JAX programs.
`jax_triton.triton_call`

Calls a Triton kernel with `jax.Array` arguments.

Example usage:
First we define a simple kernel that adds two vectors.
```python
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    block_size: tl.constexpr,
):
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  mask = offsets < 8
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)
```
Then we use `triton_call` to call it from JAX.
```python
import jax
import jax.numpy as jnp
import jax_triton as jt

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  return jt.triton_call(
      x,
      y,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // block_size,),
      block_size=block_size)

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Union[jax.Array, bool, int, float]` | Inputs for the Triton kernel. | `()` |
| `kernel` | `triton.JITFunction` | A Triton kernel (e.g. a function decorated with `triton.jit`). | required |
| `out_shape` | `Union[ShapeDtype, Sequence[ShapeDtype]]` | A `jax.ShapeDtypeStruct` (or a sequence of them) describing the shape(s) and dtype(s) of the output(s). | required |
| `grid` | `GridOrLambda` | An integer, tuple of up to 3 integers, or a function that returns a tuple of up to 3 integers. When `grid` is a function, it is called with the metaparams and should return the launch grid. | required |
| `input_output_aliases` | `Optional[Dict[int, int]]` | A dictionary mapping input argument indices to output indices. Providing a mapping will alias the corresponding buffers. | `None` |
| `zeroed_outputs` | `Union[Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]]` | A sequence of indices, or a function returning a sequence of indices, for outputs that should be zeroed before the kernel is launched. | `()` |
| `num_warps` | `int` | The number of warps used to execute the Triton kernel. | `4` |
| `num_stages` | `int` | The number of stages emitted by the Triton compiler. | `2` |
| `debug` | `bool` | Prints out intermediate IRs if True for debugging purposes. | `False` |
| `serialized_metadata` | `bytes` | Arbitrary metadata that will be added into the serialized kernel call. | `b''` |
| `**metaparams` | `Any` | Additional keyword arguments that will be provided to a `grid` (if it is a function) and to the Triton kernel as compile-time constants. | `{}` |
Returns:

| Type | Description |
|---|---|
| `Any` | Outputs from the Triton kernel. |
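For instance, `grid` can be provided as a function of the metaparams. Here is a minimal sketch reusing the `add_kernel` defined above; the helper name `add_with_lambda_grid` is illustrative, and it assumes the callable grid receives the metaparams as a dict, as in Triton's own launch grids:

```python
import jax
import jax.numpy as jnp
import jax_triton as jt

def add_with_lambda_grid(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  return jt.triton_call(
      x,
      y,
      kernel=add_kernel,
      out_shape=out_shape,
      # Assumption: a callable grid is invoked with the metaparams dict and
      # returns the launch grid, Triton-style.
      grid=lambda meta: (x.size // meta["block_size"],),
      num_warps=4,   # default shown in the table above
      # Metaparam forwarded to add_kernel's `block_size: tl.constexpr`.
      block_size=8)

print(jax.jit(add_with_lambda_grid)(jnp.arange(8), jnp.arange(8, 16)))
```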