From 510d8b6f49fb434bfdadd02feaeafdac833ae8b9 Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 7 Feb 2024 23:33:32 -0500 Subject: [PATCH] fix imports Signed-off-by: nstarman --- tests/test_myarray.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_myarray.py b/tests/test_myarray.py index e9eb8e9..d077377 100644 --- a/tests/test_myarray.py +++ b/tests/test_myarray.py @@ -4,12 +4,12 @@ import jax.numpy as jnp import pytest from jax import Array -from jax.experimental.array_api._data_type_functions import FInfo, IInfo -from jax.experimental.array_api._set_functions import ( - UniqueAllResult, - UniqueCountsResult, - UniqueInverseResult, +from jax._src.numpy.setops import ( + _UniqueAllResult, + _UniqueCountsResult, + _UniqueInverseResult, ) +from jax.experimental.array_api._data_type_functions import FInfo, IInfo from myarray import MyArray import array_api_jax_compat as xp @@ -1103,7 +1103,7 @@ def test_unique_all(): got = xp.unique_all(x) expected = jax_xp.unique_all(x.array) - assert isinstance(got, UniqueAllResult) + assert isinstance(got, _UniqueAllResult) assert isinstance(got.values, MyArray) assert jnp.array_equal(got.values, expected.values) @@ -1125,7 +1125,7 @@ def test_unique_counts(): got = xp.unique_counts(x) expected = jax_xp.unique_counts(x.array) - assert isinstance(got, UniqueCountsResult) + assert isinstance(got, _UniqueCountsResult) assert isinstance(got.values, MyArray) assert jnp.array_equal(got.values.array, expected.values) @@ -1141,7 +1141,7 @@ def test_unique_inverse(): got = xp.unique_inverse(x) expected = jax_xp.unique_inverse(x.array) - assert isinstance(got, UniqueInverseResult) + assert isinstance(got, _UniqueInverseResult) assert isinstance(got.values, MyArray) assert jnp.array_equal(got.values.array, expected.values)