Skip to content

Commit

Permalink
[data] feat: Implement multi-directional sort over Ray Data datasets (r…
Browse files Browse the repository at this point in the history
  • Loading branch information
sumanthratna authored Dec 24, 2024
1 parent 8154528 commit f186d61
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 26 deletions.
28 changes: 15 additions & 13 deletions python/ray/data/_internal/planner/exchange/sort_task_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ def __init__(
raise ValueError(
"Length of `descending` does not match the length of the key."
)
if len(set(descending)) != 1:
raise ValueError("Sorting with mixed key orders not supported yet.")
self._columns = key
self._descending = descending
if boundaries:
Expand All @@ -58,17 +56,17 @@ def __init__(
def get_columns(self) -> List[str]:
return self._columns

def get_descending(self) -> bool:
return self._descending[0]
def get_descending(self) -> List[bool]:
return self._descending

def to_arrow_sort_args(self) -> List[Tuple[str, str]]:
return [
(key, "descending" if self._descending[0] else "ascending")
for key in self._columns
(key, "descending" if desc else "ascending")
for key, desc in zip(self._columns, self._descending)
]

def to_pandas_sort_args(self) -> Tuple[List[str], bool]:
return self._columns, not self._descending[0]
def to_pandas_sort_args(self) -> Tuple[List[str], List[bool]]:
return self._columns, [not desc for desc in self._descending]

def validate_schema(self, schema: Optional[Union[type, "pyarrow.lib.Schema"]]):
"""Check the key function is valid on the given schema."""
Expand Down Expand Up @@ -209,11 +207,15 @@ def is_na(x):
return np.isnan(x)
return False

def key_fn_with_nones(sample):
return tuple(NULL_SENTINEL if is_na(x) else x for x in sample)

# Sort the list, but Nones should be NULL_SENTINEL to ensure safe sorting.
samples_list = sorted(samples_list, key=key_fn_with_nones)
# To allow multi-directional sort, we utilize Python's stable sort: we
# sort several times with different directions. We do this in reverse, so
# that the last key we sort by is the primary sort key passed by the user.
for i, desc in list(enumerate(sort_key.get_descending()))[::-1]:
# Sort the list, but Nones should be NULL_SENTINEL to ensure safe sorting.
samples_list.sort(
key=lambda sample: NULL_SENTINEL if is_na(sample[i]) else sample[i],
reverse=desc,
)

# Each boundary corresponds to a quantile of the data.
quantile_indices = [
Expand Down
8 changes: 5 additions & 3 deletions python/ray/data/_internal/planner/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ def fn(
blocks, sort_key, num_outputs, sample_bar
)
else:
# User-specified boundaries only partition by the primary sort key.
# When that key is sorted in descending order, reverse `boundaries` so
# that the partitions are produced in descending order as well.
boundaries = [(b,) for b in sort_key.boundaries]
if sort_key.get_descending()[0]:
boundaries = boundaries[::-1]
num_outputs = len(boundaries) + 1
_, ascending = sort_key.to_pandas_sort_args()
if not ascending:
boundaries.reverse()
sort_spec = SortTaskSpec(
boundaries=boundaries, sort_key=sort_key, batch_format=batch_format
)
Expand Down
38 changes: 34 additions & 4 deletions python/ray/data/_internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,9 +720,29 @@ def unify_block_metadata_schema(

def find_partition_index(
table: Union["pyarrow.Table", "pandas.DataFrame"],
desired: List[Any],
desired: Tuple[Union[int, float]],
sort_key: "SortKey",
) -> int:
"""For the given block, find the index where the desired value should be
added, to maintain sorted order.
We do this by iterating over each column, starting with the primary sort key,
and binary searching for the desired value in the column. Each binary search
shortens the "range" of indices (represented by ``left`` and ``right``, which
are indices of rows) where the desired value could be inserted.
Args:
table: The block to search in.
desired: A single tuple representing the boundary to partition at.
``len(desired)`` must be less than or equal to the number of columns
being sorted.
sort_key: The sort key to use for sorting, providing the columns to be
sorted and their directions.
Returns:
The index where the desired value should be inserted to maintain sorted
order.
"""
columns = sort_key.get_columns()
descending = sort_key.get_descending()

Expand All @@ -745,7 +765,13 @@ def find_partition_index(
col_vals[null_mask] = NULL_SENTINEL

prevleft = left
if descending is True:
if descending[i] is True:
# ``np.searchsorted`` expects the array to be sorted in ascending
# order, so we pass ``sorter``, which is an array of integer indices
# that sort ``col_vals`` into ascending order. The returned index
# is an index into the ascending order of ``col_vals``, so we need
# to subtract it from ``len(col_vals)`` to get the index in the
# original descending order of ``col_vals``.
left = prevleft + (
len(col_vals)
- np.searchsorted(
Expand All @@ -767,10 +793,14 @@ def find_partition_index(
else:
left = prevleft + np.searchsorted(col_vals, desired_val, side="left")
right = prevleft + np.searchsorted(col_vals, desired_val, side="right")
return right if descending is True else left
return right if descending[0] is True else left


def find_partitions(table, boundaries, sort_key):
def find_partitions(
table: Union["pyarrow.Table", "pandas.DataFrame"],
boundaries: List[Tuple[Union[int, float]]],
sort_key: "SortKey",
):
partitions = []

# For each boundary value, count the number of items that are less
Expand Down
5 changes: 2 additions & 3 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2476,9 +2476,8 @@ def sort(
The `key` parameter must be specified (i.e., it cannot be `None`).
.. note::
The `descending` parameter must be a boolean, or a list of booleans.
If it is a list, all items in the list must share the same direction.
Multi-directional sort is not supported yet.
If provided, the `boundaries` parameter can only be used to partition
the first sort key.
Examples:
>>> import ray
Expand Down
12 changes: 9 additions & 3 deletions python/ray/data/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_sort_multiple_keys_produces_equally_sized_blocks(ray_start_regular):
[{"a": i, "b": j} for i in range(2) for j in range(5)], override_num_blocks=5
)

ds_sorted = ds.sort(["a", "b"])
ds_sorted = ds.sort(["a", "b"], descending=[False, True])

num_rows_per_block = [
bundle.num_rows() for bundle in ds_sorted.iter_internal_ref_bundles()
Expand Down Expand Up @@ -228,8 +228,14 @@ def test_sort_with_multiple_keys(ray_start_regular, descending, batch_format):
batch_format=batch_format,
batch_size=None,
)
df.sort_values(["a", "b", "c"], inplace=True, ascending=not descending)
sorted_ds = ds.repartition(num_blocks).sort(["a", "b", "c"], descending=descending)
df.sort_values(
["a", "b", "c"],
inplace=True,
ascending=[not descending, descending, not descending],
)
sorted_ds = ds.repartition(num_blocks).sort(
["a", "b", "c"], descending=[descending, not descending, descending]
)

# Number of blocks is preserved
assert len(sorted_ds._block_num_rows()) == num_blocks
Expand Down
149 changes: 149 additions & 0 deletions python/ray/data/tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, Optional

import numpy as np
import pyarrow as pa
import pytest
from typing_extensions import Hashable

Expand All @@ -12,10 +13,13 @@
trace_allocation,
trace_deallocation,
)
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
from ray.data._internal.remote_fn import _make_hashable, cached_remote_fn
from ray.data._internal.util import (
NULL_SENTINEL,
_check_pyarrow_version,
find_partition_index,
find_partitions,
iterate_with_retry,
)
from ray.data.tests.conftest import * # noqa: F401, F403
Expand Down Expand Up @@ -220,6 +224,151 @@ def __next__(self):
assert list(iterate_with_retry(MockIterable, description="get item")) == [0, 1, 2]


def test_find_partition_index_single_column_ascending():
    """Check insertion indices for one column sorted in ascending order."""
    block = pa.table({"value": [1, 2, 2, 3, 5]})
    asc_key = SortKey(key=["value"], descending=[False])
    # (boundary, expected insertion index) pairs, checked in order.
    cases = [
        ((0,), 0),  # every entry is greater than 0
        ((2,), 1),  # index of the first 2
        ((4,), 4),  # falls between 3 and 5
        ((6,), 5),  # every entry is smaller than 6
    ]
    for boundary, expected in cases:
        assert find_partition_index(block, boundary, asc_key) == expected


def test_find_partition_index_single_column_descending():
    """Check insertion indices for one column sorted in descending order."""
    block = pa.table({"value": [5, 3, 2, 2, 1]})
    desc_key = SortKey(key=["value"], descending=[True])
    # (boundary, expected insertion index) pairs, checked in order.
    cases = [
        ((6,), 0),  # goes in front of the leading 5
        ((3,), 2),  # goes right after the 3
        ((2,), 4),  # goes right after the last 2
        ((0,), 5),  # every entry is greater than 0
    ]
    for boundary, expected in cases:
        assert find_partition_index(block, boundary, desc_key) == expected


def test_find_partition_index_multi_column():
    """Check insertion indices for two columns with opposite directions.

    The block is ordered by ``col1`` ascending and, within ties, by ``col2``
    descending.
    """
    block = pa.table({"col1": [1, 1, 1, 2, 2], "col2": [3, 2, 1, 2, 1]})
    mixed_key = SortKey(key=["col1", "col2"], descending=[False, True])
    # (boundary, expected insertion index) pairs, checked in order.
    cases = [
        ((1, 3), 0),  # position of the existing (1, 3) row
        ((1, 2), 1),  # col1 ties; lands at the existing (1, 2) row
        ((2, 2), 3),  # position of the existing (2, 2) row
        ((0, 4), 0),  # col1 value 0 is below every col1 entry
        ((2, 0), 5),  # past the final (2, 1) row
    ]
    for boundary, expected in cases:
        assert find_partition_index(block, boundary, mixed_key) == expected


def test_find_partition_index_with_nulls():
    """Check insertion indices when the column ends with null entries.

    The expectations treat nulls as ordered after every real value.
    """
    block = pa.table({"value": [1, 2, 3, None, None]})
    asc_key = SortKey(key=["value"], descending=[False])
    # (boundary, expected insertion index) pairs, checked in order.
    cases = [
        ((2,), 1),  # after 1, in front of 2
        ((4,), 3),  # after the real values, in front of the nulls
        ((None,), 5),  # a null boundary lands at the very end
    ]
    for boundary, expected in cases:
        assert find_partition_index(block, boundary, asc_key) == expected


def test_find_partition_index_duplicates():
    """Check insertion indices in an ascending column of identical values."""
    block = pa.table({"value": [2, 2, 2, 2, 2]})
    asc_key = SortKey(key=["value"], descending=[False])
    # (boundary, expected insertion index) pairs, checked in order.
    cases = [
        ((2,), 0),  # equal to every entry: first matching index
        ((1,), 0),  # smaller than every entry
        ((3,), 5),  # larger than every entry
    ]
    for boundary, expected in cases:
        assert find_partition_index(block, boundary, asc_key) == expected


def test_find_partition_index_duplicates_descending():
    """Check insertion indices in a descending column of identical values."""
    block = pa.table({"value": [2, 2, 2, 2, 2]})
    desc_key = SortKey(key=["value"], descending=[True])
    # (boundary, expected insertion index) pairs, checked in order.
    cases = [
        ((2,), 5),  # equal to every entry: lands after all of them
        ((1,), 5),  # smaller than every entry
        ((3,), 0),  # larger than every entry
    ]
    for boundary, expected in cases:
        assert find_partition_index(block, boundary, desc_key) == expected


def test_find_partitions_single_column_ascending():
    """Split an ascending single-column block at four boundaries."""
    block = pa.table({"value": [1, 2, 2, 3, 5]})
    asc_key = SortKey(key=["value"], descending=[False])
    cut_points = [(0,), (2,), (4,), (6,)]
    partitions = find_partitions(block, cut_points, asc_key)
    # Expected contents of each partition, in output order.
    expected = [
        {"value": []},  # <0
        {"value": [1]},  # [0,2)
        {"value": [2, 2, 3]},  # [2,4)
        {"value": [5]},  # [4,6)
        {"value": []},  # >=6
    ]
    assert len(partitions) == 5
    for part, want in zip(partitions, expected):
        assert part.to_pydict() == want


def test_find_partitions_single_column_descending():
    """Split a descending single-column block at four boundaries."""
    block = pa.table({"value": [5, 3, 2, 2, 1]})
    desc_key = SortKey(key=["value"], descending=[True])
    cut_points = [(6,), (3,), (2,), (0,)]
    partitions = find_partitions(block, cut_points, desc_key)
    # Expected contents of each partition, in output order.
    expected = [
        {"value": []},  # >=6
        {"value": [5, 3]},  # [3, 6)
        {"value": [2, 2]},  # [2, 3)
        {"value": [1]},  # [0, 2)
        {"value": []},  # <0
    ]
    assert len(partitions) == 5
    for part, want in zip(partitions, expected):
        assert part.to_pydict() == want


def test_find_partitions_multi_column_ascending_first():
    """Split a block keyed by ``col1`` ascending, then ``col2`` descending."""
    block = pa.table(
        {"col1": [1, 1, 1, 1, 1, 2, 2], "col2": [4, 3, 2.5, 2, 1, 2, 1]}
    )
    mixed_key = SortKey(key=["col1", "col2"], descending=[False, True])
    cut_points = [(1, 3), (1, 2), (2, 2), (2, 0)]
    partitions = find_partitions(block, cut_points, mixed_key)
    # Expected contents of each partition, in output order.
    expected = [
        {"col1": [1], "col2": [4]},
        {"col1": [1, 1], "col2": [3, 2.5]},
        {"col1": [1, 1], "col2": [2, 1]},
        {"col1": [2, 2], "col2": [2, 1]},
        {"col1": [], "col2": []},
    ]
    assert len(partitions) == 5
    for part, want in zip(partitions, expected):
        assert part.to_pydict() == want


def test_find_partitions_multi_column_descending_first():
    """Split a block keyed by ``col1`` descending, then ``col2`` ascending."""
    block = pa.table(
        {"col1": [2, 2, 1, 1, 1, 1, 1], "col2": [1, 2, 1, 2, 3, 4, 5]}
    )
    mixed_key = SortKey(key=["col1", "col2"], descending=[True, False])
    cut_points = [(2, 0), (2, 2), (1, 2), (1, 6)]
    partitions = find_partitions(block, cut_points, mixed_key)
    # Expected contents of each partition, in output order.
    expected = [
        {"col1": [], "col2": []},
        {"col1": [2, 2], "col2": [1, 2]},
        {"col1": [1, 1], "col2": [1, 2]},
        {"col1": [1, 1, 1], "col2": [3, 4, 5]},
        {"col1": [], "col2": []},
    ]
    assert len(partitions) == 5
    for part, want in zip(partitions, expected):
        assert part.to_pydict() == want


def test_find_partitions_with_nulls():
    """Split an ascending block whose trailing entries are null."""
    block = pa.table({"value": [1, 2, 3, None, None]})
    asc_key = SortKey(key=["value"], descending=[False])
    cut_points = [(2,), (4,)]
    partitions = find_partitions(block, cut_points, asc_key)
    # Expected contents of each partition, in output order.
    expected = [
        {"value": [1]},  # <2
        {"value": [2, 3]},  # [2, 4)
        {"value": [None, None]},  # >=4; the nulls end up here
    ]
    assert len(partitions) == 3
    for part, want in zip(partitions, expected):
        assert part.to_pydict() == want


def test_find_partitions_duplicates():
    """Split an ascending block in which every value is identical."""
    block = pa.table({"value": [2, 2, 2, 2, 2]})
    asc_key = SortKey(key=["value"], descending=[False])
    cut_points = [(1,), (2,), (3,)]
    partitions = find_partitions(block, cut_points, asc_key)
    # Expected contents of each partition, in output order.
    expected = [
        {"value": []},  # <1
        {"value": []},  # [1,2)
        {"value": [2, 2, 2, 2, 2]},  # [2,3)
        {"value": []},  # >=3
    ]
    assert len(partitions) == 4
    for part, want in zip(partitions, expected):
        assert part.to_pydict() == want


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit f186d61

Please sign in to comment.