Skip to content

Commit

Permalink
Merge pull request #191 from jax-ml:jax-pin
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 548220849
  • Loading branch information
The jax_triton Authors committed Jul 14, 2023
2 parents 56ffd00 + bd0611e commit ec5b82d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
5 changes: 4 additions & 1 deletion jax_triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@

get_compute_capability = gpu_triton.get_compute_capability
if jaxlib.version.__version_info__ >= (0, 4, 14):
get_serialized_metadata = gpu_triton.get_serialized_metadata
try:
get_serialized_metadata = gpu_triton.get_serialized_metadata
except AttributeError:
get_serialized_metadata = None

# trailer
del gpu_triton
Expand Down
24 changes: 21 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,31 @@ name = "jax-triton"
dynamic = ["version"]
description = "JAX + OpenAI Triton integration"
readme = "README.md"
requires-python = ">=3.8,<3.11"
requires-python = ">=3.9,<3.11"
dependencies = [
"absl-py>=1.4.0",
"jax @ git+https://github.com/google/jax@d2d30bc4fdd721632af5d0974588cfd71d10b54b",
"triton @ git+https://github.com/openai/triton@acf1ede5bfd0729bb99e848fb8edace0a24da8d4#subdirectory=python",
"jax @ git+https://github.com/google/jax@f7eef2eda8b2d36b7d6f928de0e8e726098bcf62",
"triton-nightly @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/2.1.dev20230714011643/triton_nightly-2.1.0.dev20230714011643-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
]

[project.optional-dependencies]
cuda12 = [
"jaxlib@ https://storage.googleapis.com/jax-releases/nightly/cuda12/jaxlib-0.4.14.dev20230714+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"
]
cuda12_pip = [
"jaxlib[cuda12_pip] @ https://storage.googleapis.com/jax-releases/nightly/cuda12/jaxlib-0.4.14.dev20230714+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"
]
cuda11 = [
"jaxlib @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230714+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"
]
cuda11 = [
"jaxlib[cuda11_pip] @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230714+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"
]
tests = [
"pytest"
]


[build-system]
requires = ["setuptools", "setuptools-scm", "cmake"]
build-backend = "setuptools.build_meta"
Expand Down

0 comments on commit ec5b82d

Please sign in to comment.