Skip to content

Commit

Permalink
further ignores + adding equal_nan in np.unique instances
Browse files Browse the repository at this point in the history
  • Loading branch information
ohrechykha committed Sep 10, 2024
1 parent e5852b3 commit 1ff29b9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/ragged/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
nonzero,
where,
)
from ._spec_set_functions import ( # pylint: disable=W0622
from ._spec_set_functions import ( # pylint: disable=R0401
unique_all,
unique_counts,
unique_inverse,
Expand Down
18 changes: 11 additions & 7 deletions src/ragged/_spec_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def unique_all(x: array, /) -> tuple[array, array, array, array]:
if isinstance(x, ragged.array):
if x.ndim == 0:
return unique_all_result(
values=ragged.array(np.unique(x._impl)), # pylint: disable=W0212
values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212
indices=ragged.array([0]),
inverse_indices=ragged.array([0]),
counts=ragged.array([1]),
Expand Down Expand Up @@ -116,7 +116,7 @@ def unique_counts(x: array, /) -> tuple[array, array]:
if isinstance(x, ragged.array):
if x.ndim == 0:
return unique_counts_result(
values=ragged.array(np.unique(x._impl)),
values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212
counts=ragged.array([1]), # pylint: disable=W0212
)
else:
Expand All @@ -125,7 +125,9 @@ def unique_counts(x: array, /) -> tuple[array, array]:
return unique_counts_result(
values=ragged.array([]), counts=ragged.array([])
)
values, counts = np.unique(x_flat.layout.data, return_counts=True) # pylint: disable=E1101
values, counts = np.unique(
x_flat.layout.data, return_counts=True, equal_nan=False
) # pylint: disable=E1101
return unique_counts_result(
values=ragged.array(values), counts=ragged.array(counts)
)
Expand Down Expand Up @@ -163,7 +165,7 @@ def unique_inverse(x: array, /) -> tuple[array, array]:
if isinstance(x, ragged.array):
if x.ndim == 0:
return unique_inverse_result(
values=ragged.array(np.unique(x._impl)), # pylint: disable=W0212
values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212
inverse_indices=ragged.array([0]),
)
else:
Expand All @@ -172,7 +174,9 @@ def unique_inverse(x: array, /) -> tuple[array, array]:
return unique_inverse_result(
values=ragged.array([]), inverse_indices=ragged.array([])
)
values, inverse_indices = np.unique(x_flat.layout.data, return_inverse=True) # pylint: disable=E1101
values, inverse_indices = np.unique(
x_flat.layout.data, return_inverse=True, equal_nan=False
) # pylint: disable=E1101

return unique_inverse_result(
values=ragged.array(values),
Expand Down Expand Up @@ -200,13 +204,13 @@ def unique_values(x: array, /) -> array:
"""
if isinstance(x, ragged.array):
if x.ndim == 0:
return ragged.array(np.unique(x._impl)) # pylint: disable=W0212
return ragged.array(np.unique(x._impl, equal_nan=False)) # pylint: disable=W0212

else:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return ragged.array([])
return ragged.array(np.unique(x_flat.layout.data)) # pylint: disable=E1101
return ragged.array(np.unique(x_flat.layout.data, equal_nan=False)) # pylint: disable=E1101
else:
err = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable]
raise TypeError(err)

0 comments on commit 1ff29b9

Please sign in to comment.