diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py index 71a5980..ebdf41e 100644 --- a/tests/triton_call_test.py +++ b/tests/triton_call_test.py @@ -18,8 +18,8 @@ from absl.testing import absltest from absl.testing import parameterized import jax +from jax import config from jax import random -from jax.config import config import jax.numpy as jnp import jax_triton as jt import numpy as np