Skip to content

Commit

Permalink
Add get_serialized_metadata function to retrieve metadata from op's…
Browse files Browse the repository at this point in the history
… opaque data.

PiperOrigin-RevId: 544608895
  • Loading branch information
chr1sj0nes authored and The jax_triton Authors committed Jun 30, 2023
1 parent b6897bc commit acbdba3
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions jax_triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@
# limitations under the License.

"""Library for JAX-Triton integrations."""
import jaxlib
from jax._src.lib import gpu_triton
from jax_triton import pallas
from jax_triton.triton_lib import triton_call
from jax_triton.utils import cdiv
from jax_triton.utils import next_power_of_2
from jax_triton.utils import strides_from_shape
from jax_triton.triton_lib import triton_call
from jax_triton.version import __version__
from jax_triton.version import __version_info__
from jax_triton import pallas
from jax._src.lib import gpu_triton

get_compute_capability = gpu_triton.get_compute_capability
if jaxlib.version.__version_info__ >= (0, 4, 14):
get_serialized_metadata = gpu_triton.get_serialized_metadata

# trailer
del gpu_triton
del jaxlib

0 comments on commit acbdba3

Please sign in to comment.