Skip to content

Commit

Permalink
returning np.empty and input dtype in all functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ohrechykha committed Sep 12, 2024
1 parent 8c0867a commit 67b5807
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/ragged/_spec_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def unique_all(x: array, /) -> tuple[array, array, array, array]:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return unique_all_result(
values=ragged.array([]),
indices=ragged.array([]),
inverse_indices=ragged.array([]),
counts=ragged.array([]),
values=ragged.array(np.empty(0, x.dtype)),
indices=ragged.array(np.empty(0, np.int64)),
inverse_indices=ragged.array(np.empty(0, np.int64)),
counts=ragged.array(np.empty(0, np.int64)),
)
values, indices, inverse_indices, counts = np.unique(
x_flat.layout.data, # pylint: disable=E1101
Expand Down Expand Up @@ -123,7 +123,8 @@ def unique_counts(x: array, /) -> tuple[array, array]:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return unique_counts_result(
values=ragged.array([]), counts=ragged.array([])
values=ragged.array(np.empty(0, x.dtype)),
counts=ragged.array(np.empty(0, np.int64)),
)
values, counts = np.unique(
x_flat.layout.data, # pylint: disable=E1101
Expand Down Expand Up @@ -174,7 +175,8 @@ def unique_inverse(x: array, /) -> tuple[array, array]:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return unique_inverse_result(
values=ragged.array([]), inverse_indices=ragged.array([])
values=ragged.array(np.empty(0, x.dtype)),
inverse_indices=ragged.array(np.empty(0, np.int64)),
)
values, inverse_indices = np.unique(
x_flat.layout.data, # pylint: disable=E1101
Expand Down Expand Up @@ -213,7 +215,7 @@ def unique_values(x: array, /) -> array:
else:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return ragged.array([])
return ragged.array(np.empty(0, x.dtype))
return ragged.array(np.unique(x_flat.layout.data, equal_nan=False)) # pylint: disable=E1101
else:
err = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable]
Expand Down

0 comments on commit 67b5807

Please sign in to comment.