Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set functions and tests #57

Merged
merged 21 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
9225efb
adding set functions and tests
ohrechykha Jul 30, 2024
14c58ac
pushing pre-commit changes
ohrechykha Jul 31, 2024
24ddbfd
ruff fixes for test_spec_set_functions.py
ohrechykha Aug 5, 2024
22fe8b9
further fixes in test_spec_set_functions.py
ohrechykha Aug 5, 2024
9153357
fixing mypy unreachable errors in _spec_set_functions.py
ohrechykha Aug 5, 2024
6c814a7
marking tests with None and empty arrays as comments
ohrechykha Aug 5, 2024
83dabd9
adding namedtuple & corresponding test fixes
ohrechykha Aug 19, 2024
2e92e46
Merge remote-tracking branch 'origin/main' into oleksii-unique
ohrechykha Aug 22, 2024
1ca9106
function if changes + test standartization
ohrechykha Aug 26, 2024
f04ccbf
correcting function ifs + test standartization
ohrechykha Aug 27, 2024
7b30c58
further test standartisation
ohrechykha Aug 28, 2024
7fb48f1
scalar handling and testing
ohrechykha Aug 29, 2024
6e8b6ee
_array_object changes, empty array handling + tests
ohrechykha Sep 10, 2024
36dfb42
implementing Jim's suggestion, disabling CI errors
ohrechykha Sep 10, 2024
e5852b3
disabling errors and warnings
ohrechykha Sep 10, 2024
1ff29b9
further ignores + adding equal_nan in np.unique instances
ohrechykha Sep 10, 2024
48893b8
better ignores
ohrechykha Sep 10, 2024
943d717
improving ignores
ohrechykha Sep 10, 2024
f719e44
Merge branch 'main' into oleksii-unique
jpivarski Sep 11, 2024
8c0867a
avoiding code duplication in _spec_array_object
ohrechykha Sep 12, 2024
67b5807
returning np.empty and input dtype in all functions
ohrechykha Sep 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ragged/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
nonzero,
where,
)
from ._spec_set_functions import (
from ._spec_set_functions import ( # pylint: disable=R0401
unique_all,
unique_counts,
unique_inverse,
Expand Down
6 changes: 5 additions & 1 deletion src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
from awkward.contents import (
Content,
EmptyArray,
ListArray,
ListOffsetArray,
NumpyArray,
Expand Down Expand Up @@ -44,7 +45,10 @@ def _shape_dtype(layout: Content) -> tuple[Shape, Dtype]:
else:
shape = (*shape, None)
node = node.content

if isinstance(node, EmptyArray):
node = node.to_NumpyArray(dtype=np.float64)
shape = shape + node.data.shape[1:]
return shape, node.data.dtype
if isinstance(node, NumpyArray):
shape = shape + node.data.shape[1:]
return shape, node.data.dtype
ohrechykha marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
109 changes: 98 additions & 11 deletions src/ragged/_spec_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@

from collections import namedtuple

import awkward as ak
import numpy as np

import ragged

from ._spec_array_object import array

unique_all_result = namedtuple( # pylint: disable=C0103
Expand Down Expand Up @@ -47,8 +52,39 @@ def unique_all(x: array, /) -> tuple[array, array, array, array]:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_all.html
"""

x # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 128") # noqa: EM101
if isinstance(x, ragged.array):
if x.ndim == 0:
return unique_all_result(
values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212
indices=ragged.array([0]),
inverse_indices=ragged.array([0]),
counts=ragged.array([1]),
)
else:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return unique_all_result(
values=ragged.array([]),
indices=ragged.array([]),
inverse_indices=ragged.array([]),
counts=ragged.array([]),
ohrechykha marked this conversation as resolved.
Show resolved Hide resolved
)
values, indices, inverse_indices, counts = np.unique(
x_flat.layout.data, # pylint: disable=E1101
return_index=True,
return_inverse=True,
return_counts=True,
equal_nan=False,
)
return unique_all_result(
values=ragged.array(values),
indices=ragged.array(indices),
inverse_indices=ragged.array(inverse_indices),
counts=ragged.array(counts),
)
else:
msg = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable]
raise TypeError(msg)


unique_counts_result = namedtuple( # pylint: disable=C0103
Expand Down Expand Up @@ -77,9 +113,29 @@ def unique_counts(x: array, /) -> tuple[array, array]:

https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_counts.html
"""

x # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 129") # noqa: EM101
if isinstance(x, ragged.array):
if x.ndim == 0:
return unique_counts_result(
values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212
counts=ragged.array([1]), # pylint: disable=W0212
)
else:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return unique_counts_result(
values=ragged.array([]), counts=ragged.array([])
ohrechykha marked this conversation as resolved.
Show resolved Hide resolved
)
values, counts = np.unique(
x_flat.layout.data, # pylint: disable=E1101
return_counts=True,
equal_nan=False,
)
return unique_counts_result(
values=ragged.array(values), counts=ragged.array(counts)
)
else:
msg = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable]
raise TypeError(msg)


unique_inverse_result = namedtuple( # pylint: disable=C0103
Expand Down Expand Up @@ -108,9 +164,31 @@ def unique_inverse(x: array, /) -> tuple[array, array]:

https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_inverse.html
"""

x # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 130") # noqa: EM101
if isinstance(x, ragged.array):
if x.ndim == 0:
return unique_inverse_result(
values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212
inverse_indices=ragged.array([0]),
)
else:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return unique_inverse_result(
values=ragged.array([]), inverse_indices=ragged.array([])
ohrechykha marked this conversation as resolved.
Show resolved Hide resolved
)
values, inverse_indices = np.unique(
x_flat.layout.data, # pylint: disable=E1101
return_inverse=True,
equal_nan=False,
)

return unique_inverse_result(
values=ragged.array(values),
inverse_indices=ragged.array(inverse_indices),
)
else:
msg = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable]
raise TypeError(msg)


def unique_values(x: array, /) -> array:
Expand All @@ -128,6 +206,15 @@ def unique_values(x: array, /) -> array:

https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_values.html
"""

x # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 131") # noqa: EM101
if isinstance(x, ragged.array):
if x.ndim == 0:
return ragged.array(np.unique(x._impl, equal_nan=False)) # pylint: disable=W0212

else:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return ragged.array([])
ohrechykha marked this conversation as resolved.
Show resolved Hide resolved
return ragged.array(np.unique(x_flat.layout.data, equal_nan=False)) # pylint: disable=E1101
else:
err = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable]
raise TypeError(err)
Loading
Loading