From acbdba3883c6dfa8d61ae4bc0d8ff0a6597223e6 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 30 Jun 2023 03:22:48 -0700 Subject: [PATCH] Add `get_serialized_metadata` function to retrieve metadata from op's opaque data. PiperOrigin-RevId: 544608895 --- jax_triton/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/jax_triton/__init__.py b/jax_triton/__init__.py index 84f6b621..298804a3 100644 --- a/jax_triton/__init__.py +++ b/jax_triton/__init__.py @@ -13,16 +13,20 @@ # limitations under the License. """Library for JAX-Triton integrations.""" +import jaxlib +from jax._src.lib import gpu_triton +from jax_triton import pallas +from jax_triton.triton_lib import triton_call from jax_triton.utils import cdiv from jax_triton.utils import next_power_of_2 from jax_triton.utils import strides_from_shape -from jax_triton.triton_lib import triton_call from jax_triton.version import __version__ from jax_triton.version import __version_info__ -from jax_triton import pallas -from jax._src.lib import gpu_triton get_compute_capability = gpu_triton.get_compute_capability +if jaxlib.version.__version_info__ >= (0, 4, 14): + get_serialized_metadata = gpu_triton.get_serialized_metadata # trailer del gpu_triton +del jaxlib