Skip to content

Commit

Permalink
fix imports
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Feb 8, 2024
1 parent ee50629 commit 510d8b6
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/test_myarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import jax.numpy as jnp
import pytest
from jax import Array
from jax.experimental.array_api._data_type_functions import FInfo, IInfo
from jax.experimental.array_api._set_functions import (
UniqueAllResult,
UniqueCountsResult,
UniqueInverseResult,
from jax._src.numpy.setops import (
_UniqueAllResult,
_UniqueCountsResult,
_UniqueInverseResult,
)
from jax.experimental.array_api._data_type_functions import FInfo, IInfo
from myarray import MyArray

import array_api_jax_compat as xp
Expand Down Expand Up @@ -1103,7 +1103,7 @@ def test_unique_all():
got = xp.unique_all(x)
expected = jax_xp.unique_all(x.array)

assert isinstance(got, UniqueAllResult)
assert isinstance(got, _UniqueAllResult)

assert isinstance(got.values, MyArray)
assert jnp.array_equal(got.values, expected.values)
Expand All @@ -1125,7 +1125,7 @@ def test_unique_counts():
got = xp.unique_counts(x)
expected = jax_xp.unique_counts(x.array)

assert isinstance(got, UniqueCountsResult)
assert isinstance(got, _UniqueCountsResult)

assert isinstance(got.values, MyArray)
assert jnp.array_equal(got.values.array, expected.values)
Expand All @@ -1141,7 +1141,7 @@ def test_unique_inverse():
got = xp.unique_inverse(x)
expected = jax_xp.unique_inverse(x.array)

assert isinstance(got, UniqueInverseResult)
assert isinstance(got, _UniqueInverseResult)

assert isinstance(got.values, MyArray)
assert jnp.array_equal(got.values.array, expected.values)
Expand Down

0 comments on commit 510d8b6

Please sign in to comment.