From f186d61cd4c2bbb4b2dcd48fe486d03db09c017a Mon Sep 17 00:00:00 2001 From: Sumanth Ratna Date: Tue, 24 Dec 2024 12:45:29 -0500 Subject: [PATCH] [data] feat: Implement multi-directional sort over Ray Data datasets (#49281) --- .../planner/exchange/sort_task_spec.py | 28 ++-- python/ray/data/_internal/planner/sort.py | 8 +- python/ray/data/_internal/util.py | 38 ++++- python/ray/data/dataset.py | 5 +- python/ray/data/tests/test_sort.py | 12 +- python/ray/data/tests/test_util.py | 149 ++++++++++++++++++ 6 files changed, 214 insertions(+), 26 deletions(-) diff --git a/python/ray/data/_internal/planner/exchange/sort_task_spec.py b/python/ray/data/_internal/planner/exchange/sort_task_spec.py index 7c67b3dbdefe..5f79a7885cbf 100644 --- a/python/ray/data/_internal/planner/exchange/sort_task_spec.py +++ b/python/ray/data/_internal/planner/exchange/sort_task_spec.py @@ -41,8 +41,6 @@ def __init__( raise ValueError( "Length of `descending` does not match the length of the key." ) - if len(set(descending)) != 1: - raise ValueError("Sorting with mixed key orders not supported yet.") self._columns = key self._descending = descending if boundaries: @@ -58,17 +56,17 @@ def __init__( def get_columns(self) -> List[str]: return self._columns - def get_descending(self) -> bool: - return self._descending[0] + def get_descending(self) -> List[bool]: + return self._descending def to_arrow_sort_args(self) -> List[Tuple[str, str]]: return [ - (key, "descending" if self._descending[0] else "ascending") - for key in self._columns + (key, "descending" if desc else "ascending") + for key, desc in zip(self._columns, self._descending) ] - def to_pandas_sort_args(self) -> Tuple[List[str], bool]: - return self._columns, not self._descending[0] + def to_pandas_sort_args(self) -> Tuple[List[str], List[bool]]: + return self._columns, [not desc for desc in self._descending] def validate_schema(self, schema: Optional[Union[type, "pyarrow.lib.Schema"]]): """Check the key function is valid on the given schema.""" @@ -209,11 +207,15 @@ def is_na(x): return np.isnan(x) return False - def key_fn_with_nones(sample): - return tuple(NULL_SENTINEL if is_na(x) else x for x in sample) - - # Sort the list, but Nones should be NULL_SENTINEL to ensure safe sorting. - samples_list = sorted(samples_list, key=key_fn_with_nones) + # To allow multi-directional sort, we utilize Python's stable sort: we + # sort several times with different directions. We do this in reverse, so + # that the last key we sort by is the primary sort key passed by the user. + for i, desc in list(enumerate(sort_key.get_descending()))[::-1]: + # Sort the list, but Nones should be NULL_SENTINEL to ensure safe sorting. + samples_list.sort( + key=lambda sample: NULL_SENTINEL if is_na(sample[i]) else sample[i], + reverse=desc, + ) # Each boundary corresponds to a quantile of the data. quantile_indices = [ diff --git a/python/ray/data/_internal/planner/sort.py b/python/ray/data/_internal/planner/sort.py index 1a14e9f260ae..ffb936d74bea 100644 --- a/python/ray/data/_internal/planner/sort.py +++ b/python/ray/data/_internal/planner/sort.py @@ -52,11 +52,13 @@ def fn( blocks, sort_key, num_outputs, sample_bar ) else: + # For user-specified boundaries (which only partition by the primary + # sort key), reverse `boundaries` so that the partitions are produced + # in descending order, as desired. boundaries = [(b,) for b in sort_key.boundaries] + if sort_key.get_descending()[0]: + boundaries = boundaries[::-1] num_outputs = len(boundaries) + 1 - _, ascending = sort_key.to_pandas_sort_args() - if not ascending: - boundaries.reverse() sort_spec = SortTaskSpec( boundaries=boundaries, sort_key=sort_key, batch_format=batch_format ) diff --git a/python/ray/data/_internal/util.py b/python/ray/data/_internal/util.py index 51757fdcc32e..0c4969b48e8c 100644 --- a/python/ray/data/_internal/util.py +++ b/python/ray/data/_internal/util.py @@ -720,9 +720,29 @@ def unify_block_metadata_schema( def find_partition_index( table: Union["pyarrow.Table", "pandas.DataFrame"], - desired: List[Any], + desired: Tuple[Union[int, float]], sort_key: "SortKey", ) -> int: + """For the given block, find the index where the desired value should be + added, to maintain sorted order. + + We do this by iterating over each column, starting with the primary sort key, + and binary searching for the desired value in the column. Each binary search + shortens the "range" of indices (represented by ``left`` and ``right``, which + are indices of rows) where the desired value could be inserted. + + Args: + table: The block to search in. + desired: A single tuple representing the boundary to partition at. + ``len(desired)`` must be less than or equal to the number of columns + being sorted. + sort_key: The sort key to use for sorting, providing the columns to be + sorted and their directions. + + Returns: + The index where the desired value should be inserted to maintain sorted + order. + """ columns = sort_key.get_columns() descending = sort_key.get_descending() @@ -745,7 +765,13 @@ def find_partition_index( col_vals[null_mask] = NULL_SENTINEL prevleft = left - if descending is True: + if descending[i] is True: + # ``np.searchsorted`` expects the array to be sorted in ascending + # order, so we pass ``sorter``, which is an array of integer indices + # that sort ``col_vals`` into ascending order. The returned index + # is an index into the ascending order of ``col_vals``, so we need + # to subtract it from ``len(col_vals)`` to get the index in the + # original descending order of ``col_vals``. left = prevleft + ( len(col_vals) - np.searchsorted( @@ -767,10 +793,14 @@ def find_partition_index( else: left = prevleft + np.searchsorted(col_vals, desired_val, side="left") right = prevleft + np.searchsorted(col_vals, desired_val, side="right") - return right if descending is True else left + return right if descending[0] is True else left -def find_partitions(table, boundaries, sort_key): +def find_partitions( + table: Union["pyarrow.Table", "pandas.DataFrame"], + boundaries: List[Tuple[Union[int, float]]], + sort_key: "SortKey", +): partitions = [] # For each boundary value, count the number of items that are less diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 0f11d7ca66ba..40faf6d9bd2f 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2476,9 +2476,8 @@ def sort( The `key` parameter must be specified (i.e., it cannot be `None`). .. note:: - The `descending` parameter must be a boolean, or a list of booleans. - If it is a list, all items in the list must share the same direction. - Multi-directional sort is not supported yet. + If provided, the `boundaries` parameter can only be used to partition + the first sort key. Examples: >>> import ray diff --git a/python/ray/data/tests/test_sort.py b/python/ray/data/tests/test_sort.py index 982cf99409d4..78a2838d4deb 100644 --- a/python/ray/data/tests/test_sort.py +++ b/python/ray/data/tests/test_sort.py @@ -57,7 +57,7 @@ def test_sort_multiple_keys_produces_equally_sized_blocks(ray_start_regular): [{"a": i, "b": j} for i in range(2) for j in range(5)], override_num_blocks=5 ) - ds_sorted = ds.sort(["a", "b"]) + ds_sorted = ds.sort(["a", "b"], descending=[False, True]) num_rows_per_block = [ bundle.num_rows() for bundle in ds_sorted.iter_internal_ref_bundles() @@ -228,8 +228,14 @@ def test_sort_with_multiple_keys(ray_start_regular, descending, batch_format): batch_format=batch_format, batch_size=None, ) - df.sort_values(["a", "b", "c"], inplace=True, ascending=not descending) - sorted_ds = ds.repartition(num_blocks).sort(["a", "b", "c"], descending=descending) + df.sort_values( + ["a", "b", "c"], + inplace=True, + ascending=[not descending, descending, not descending], + ) + sorted_ds = ds.repartition(num_blocks).sort( + ["a", "b", "c"], descending=[descending, not descending, descending] + ) # Number of blocks is preserved assert len(sorted_ds._block_num_rows()) == num_blocks diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py index 51a4a769e35a..7c33eb27de92 100644 --- a/python/ray/data/tests/test_util.py +++ b/python/ray/data/tests/test_util.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Optional import numpy as np +import pyarrow as pa import pytest from typing_extensions import Hashable @@ -12,10 +13,13 @@ trace_allocation, trace_deallocation, ) +from ray.data._internal.planner.exchange.sort_task_spec import SortKey from ray.data._internal.remote_fn import _make_hashable, cached_remote_fn from ray.data._internal.util import ( NULL_SENTINEL, _check_pyarrow_version, + find_partition_index, + find_partitions, iterate_with_retry, ) from ray.data.tests.conftest import * # noqa: F401, F403 @@ -220,6 +224,151 @@ def __next__(self): assert list(iterate_with_retry(MockIterable, description="get item")) == [0, 1, 2] +def test_find_partition_index_single_column_ascending(): + table = pa.table({"value": [1, 2, 2, 3, 5]}) + sort_key = SortKey(key=["value"], descending=[False]) + assert find_partition_index(table, (0,), sort_key) == 0 # all entries > 0 + assert find_partition_index(table, (2,), sort_key) == 1 # first match index + assert find_partition_index(table, (4,), sort_key) == 4 # belongs after 3, before 5 + assert find_partition_index(table, (6,), sort_key) == 5 # all entries < 6 + + +def test_find_partition_index_single_column_descending(): + table = pa.table({"value": [5, 3, 2, 2, 1]}) + sort_key = SortKey(key=["value"], descending=[True]) + assert find_partition_index(table, (6,), sort_key) == 0 # belongs before 5 + assert find_partition_index(table, (3,), sort_key) == 2 # after the last 3 + assert find_partition_index(table, (2,), sort_key) == 4 # after the last 2 + assert find_partition_index(table, (0,), sort_key) == 5 # all entries > 0 + + +def test_find_partition_index_multi_column(): + # Table sorted by col1 asc, then col2 desc. + table = pa.table({"col1": [1, 1, 1, 2, 2], "col2": [3, 2, 1, 2, 1]}) + sort_key = SortKey(key=["col1", "col2"], descending=[False, True]) + # Insert value (1,3) -> belongs before (1,2) + assert find_partition_index(table, (1, 3), sort_key) == 0 + # Insert value (1,2) -> belongs after the first (1,3) and before (1,2) + # because col1 ties, col2 descending + assert find_partition_index(table, (1, 2), sort_key) == 1 + # Insert value (2,2) -> belongs right before (2,2) that starts at index 3 + assert find_partition_index(table, (2, 2), sort_key) == 3 + # Insert value (0, 4) -> belongs at index 0 (all col1 > 0) + assert find_partition_index(table, (0, 4), sort_key) == 0 + # Insert value (2,0) -> belongs after (2,1) + assert find_partition_index(table, (2, 0), sort_key) == 5 + + +def test_find_partition_index_with_nulls(): + # _NullSentinel is sorted greater, so they appear after all real values. + table = pa.table({"value": [1, 2, 3, None, None]}) + sort_key = SortKey(key=["value"], descending=[False]) + # Insert (2,) -> belongs after 1, before 2 => index 1 + # (But the actual find_partition_index uses the table as-is.) + assert find_partition_index(table, (2,), sort_key) == 1 + # Insert (4,) -> belongs before any null => index 3 + assert find_partition_index(table, (4,), sort_key) == 3 + # Insert (None,) -> always belongs at the end + assert find_partition_index(table, (None,), sort_key) == 5 + + +def test_find_partition_index_duplicates(): + table = pa.table({"value": [2, 2, 2, 2, 2]}) + sort_key = SortKey(key=["value"], descending=[False]) + # Insert (2,) in a table of all 2's -> first matching index is 0 + assert find_partition_index(table, (2,), sort_key) == 0 + # Insert (1,) -> belongs at index 0 + assert find_partition_index(table, (1,), sort_key) == 0 + # Insert (3,) -> belongs at index 5 + assert find_partition_index(table, (3,), sort_key) == 5 + + +def test_find_partition_index_duplicates_descending(): + table = pa.table({"value": [2, 2, 2, 2, 2]}) + sort_key = SortKey(key=["value"], descending=[True]) + # Insert (2,) in a table of all 2's -> belongs at index 5 + assert find_partition_index(table, (2,), sort_key) == 5 + # Insert (1,) -> belongs at index 5 + assert find_partition_index(table, (1,), sort_key) == 5 + # Insert (3,) -> belongs at index 0 + assert find_partition_index(table, (3,), sort_key) == 0 + + +def test_find_partitions_single_column_ascending(): + table = pa.table({"value": [1, 2, 2, 3, 5]}) + sort_key = SortKey(key=["value"], descending=[False]) + boundaries = [(0,), (2,), (4,), (6,)] + partitions = find_partitions(table, boundaries, sort_key) + assert len(partitions) == 5 + assert partitions[0].to_pydict() == {"value": []} # <0 + assert partitions[1].to_pydict() == {"value": [1]} # [0,2) + assert partitions[2].to_pydict() == {"value": [2, 2, 3]} # [2,4) + assert partitions[3].to_pydict() == {"value": [5]} # [4,6) + assert partitions[4].to_pydict() == {"value": []} # >=6 + + +def test_find_partitions_single_column_descending(): + table = pa.table({"value": [5, 3, 2, 2, 1]}) + sort_key = SortKey(key=["value"], descending=[True]) + boundaries = [(6,), (3,), (2,), (0,)] + partitions = find_partitions(table, boundaries, sort_key) + assert len(partitions) == 5 + assert partitions[0].to_pydict() == {"value": []} # >=6 + assert partitions[1].to_pydict() == {"value": [5, 3]} # [3, 6) + assert partitions[2].to_pydict() == {"value": [2, 2]} # [2, 3) + assert partitions[3].to_pydict() == {"value": [1]} # [0, 2) + assert partitions[4].to_pydict() == {"value": []} # <0 + + +def test_find_partitions_multi_column_ascending_first(): + table = pa.table({"col1": [1, 1, 1, 1, 1, 2, 2], "col2": [4, 3, 2.5, 2, 1, 2, 1]}) + sort_key = SortKey(key=["col1", "col2"], descending=[False, True]) + boundaries = [(1, 3), (1, 2), (2, 2), (2, 0)] + partitions = find_partitions(table, boundaries, sort_key) + assert len(partitions) == 5 + assert partitions[0].to_pydict() == {"col1": [1], "col2": [4]} + assert partitions[1].to_pydict() == {"col1": [1, 1], "col2": [3, 2.5]} + assert partitions[2].to_pydict() == {"col1": [1, 1], "col2": [2, 1]} + assert partitions[3].to_pydict() == {"col1": [2, 2], "col2": [2, 1]} + assert partitions[4].to_pydict() == {"col1": [], "col2": []} + + +def test_find_partitions_multi_column_descending_first(): + table = pa.table({"col1": [2, 2, 1, 1, 1, 1, 1], "col2": [1, 2, 1, 2, 3, 4, 5]}) + sort_key = SortKey(key=["col1", "col2"], descending=[True, False]) + boundaries = [(2, 0), (2, 2), (1, 2), (1, 6)] + partitions = find_partitions(table, boundaries, sort_key) + assert len(partitions) == 5 + assert partitions[0].to_pydict() == {"col1": [], "col2": []} + assert partitions[1].to_pydict() == {"col1": [2, 2], "col2": [1, 2]} + assert partitions[2].to_pydict() == {"col1": [1, 1], "col2": [1, 2]} + assert partitions[3].to_pydict() == {"col1": [1, 1, 1], "col2": [3, 4, 5]} + assert partitions[4].to_pydict() == {"col1": [], "col2": []} + + +def test_find_partitions_with_nulls(): + table = pa.table({"value": [1, 2, 3, None, None]}) + sort_key = SortKey(key=["value"], descending=[False]) + boundaries = [(2,), (4,)] + partitions = find_partitions(table, boundaries, sort_key) + assert len(partitions) == 3 + assert partitions[0].to_pydict() == {"value": [1]} # <2 + assert partitions[1].to_pydict() == {"value": [2, 3]} # [2, 4) + assert partitions[2].to_pydict() == {"value": [None, None]} # >=4 + + +def test_find_partitions_duplicates(): + table = pa.table({"value": [2, 2, 2, 2, 2]}) + sort_key = SortKey(key=["value"], descending=[False]) + boundaries = [(1,), (2,), (3,)] + partitions = find_partitions(table, boundaries, sort_key) + assert len(partitions) == 4 + assert partitions[0].to_pydict() == {"value": []} # <1 + assert partitions[1].to_pydict() == {"value": []} # [1,2) + assert partitions[2].to_pydict() == {"value": [2, 2, 2, 2, 2]} # [2,3) + assert partitions[3].to_pydict() == {"value": []} # >=3 + + if __name__ == "__main__": import sys