Skip to content

Commit

Permalink
correcting function ifs + test standartization
Browse files Browse the repository at this point in the history
  • Loading branch information
ohrechykha committed Aug 27, 2024
1 parent 1ca9106 commit f04ccbf
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
20 changes: 9 additions & 11 deletions src/ragged/_spec_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def unique_all(x: array, /) -> tuple[array, array, array, array]:
counts=ragged.array(counts),
)
else:
msg = f"Expected ragged type but got {type(x)}" # type: ignore
raise TypeError(msg) # type: ignore
msg = f"Expected ragged type but got {type(x)}" # type: ignore
raise TypeError(msg) # type: ignore


unique_counts_result = namedtuple( # pylint: disable=C0103
Expand Down Expand Up @@ -118,8 +118,8 @@ def unique_counts(x: array, /) -> tuple[array, array]:
values=ragged.array(values), counts=ragged.array(counts)
)
else:
msg = f"Expected ragged type but got {type(x)}" # type: ignore
raise TypeError(msg) # type: ignore
msg = f"Expected ragged type but got {type(x)}" # type: ignore
raise TypeError(msg) # type: ignore


unique_inverse_result = namedtuple( # pylint: disable=C0103
Expand Down Expand Up @@ -150,9 +150,7 @@ def unique_inverse(x: array, /) -> tuple[array, array]:
"""
if isinstance(x, ragged.array):
if ak.is_scalar(x):
return unique_inverse_result(
values=x, inverse_indices=ragged.array([0])
)
return unique_inverse_result(values=x, inverse_indices=ragged.array([0]))
else:
x_flat = ak.ravel(x._impl)
values, inverse_indices = np.unique(x_flat.layout.data, return_inverse=True)
Expand All @@ -162,8 +160,8 @@ def unique_inverse(x: array, /) -> tuple[array, array]:
inverse_indices=ragged.array(inverse_indices),
)
else:
msg = f"Expected ragged type but got {type(x)}" # type: ignore
raise TypeError(msg) # type: ignore
msg = f"Expected ragged type but got {type(x)}" # type: ignore
raise TypeError(msg) # type: ignore


def unique_values(x: array, /) -> array:
Expand All @@ -189,5 +187,5 @@ def unique_values(x: array, /) -> array:
x_flat = ak.ravel(x._impl)
return ragged.array(np.unique(x_flat.layout.data))
else:
err = f"Expected ragged type but got {type(x)}" # type: ignore
raise TypeError(err) # type: ignore
err = f"Expected ragged type but got {type(x)}" # type: ignore
raise TypeError(err) # type: ignore
29 changes: 16 additions & 13 deletions tests/test_spec_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def test_can_take_list():

def test_can_take_empty_arr():
# with pytest.raises(TypeError):
assert ak.to_list(ragged.unique_values(ragged.array([]))) == ak.to_list(
ragged.array([])
)
assert ak.to_list(ragged.unique_values(ragged.array([]))) == ak.to_list(
ragged.array([])
)


def test_can_take_moredimensions():
Expand Down Expand Up @@ -99,13 +99,12 @@ def test_can_count_scalar():


def test_can_inverse_list():
arr=ragged.array([1, 2, 4, 3, 4, 5, 6, 20])
expected_values=ragged.array([1,2,3,4,5,6,20])
expected_inverse=ragged.array([0, 1, 3, 2, 3, 4, 5, 6])
values, inverse= ragged.unique_inverse(arr)
assert ak.to_list(expected_values)==ak.to_list(values)
assert ak.to_list(expected_inverse)==ak.to_list(inverse)

arr = ragged.array([1, 2, 4, 3, 4, 5, 6, 20])
expected_values = ragged.array([1, 2, 3, 4, 5, 6, 20])
expected_inverse = ragged.array([0, 1, 3, 2, 3, 4, 5, 6])
values, inverse = ragged.unique_inverse(arr)
assert ak.to_list(expected_values) == ak.to_list(values)
assert ak.to_list(expected_inverse) == ak.to_list(inverse)


def test_can_inverse_simple_array():
Expand Down Expand Up @@ -138,19 +137,23 @@ def test_can_inverse_scalar():
# unique_all tests
def test_can_all_none():
with pytest.raises(TypeError):
arr=None
arr = None
expected_unique_values = ragged.array(None)
expected_unique_indices = ragged.array(None)
expected_unique_inverse = ragged.array(None)
expected_unique_counts = ragged.array(None)
unique_values, unique_indices, unique_inverse, unique_counts = ragged.unique_all(arr)
(
unique_values,
unique_indices,
unique_inverse,
unique_counts,
) = ragged.unique_all(arr)
assert ak.to_list(unique_values) == ak.to_list(expected_unique_values)
assert ak.to_list(unique_indices) == ak.to_list(expected_unique_indices)
assert ak.to_list(unique_inverse) == ak.to_list(expected_unique_inverse)
assert ak.to_list(unique_counts) == ak.to_list(expected_unique_counts)



def test_can_all_list():
arr = ragged.array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4])
expected_unique_values = ragged.array([1, 2, 3, 4])
Expand Down

0 comments on commit f04ccbf

Please sign in to comment.