Skip to content

Commit

Permalink
Bump numpy>=1.24 and pyarrow>=14.0.1 minimum versions (#8837)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau authored Aug 21, 2024
1 parent 30e01fb commit 5bbceb7
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 63 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,24 @@ jobs:
- os: ubuntu-latest
environment: mindeps
label: numpy
extra_packages: [numpy=1.21]
extra_packages: [numpy=1.24]
partition: "ci1"
- os: ubuntu-latest
environment: mindeps
label: numpy
extra_packages: [numpy=1.21]
extra_packages: [numpy=1.24]
partition: "not ci1"

# dask.dataframe P2P shuffle
- os: ubuntu-latest
environment: mindeps
label: pandas
extra_packages: [numpy=1.21, pandas=2.0, pyarrow=7, pyarrow-hotfix]
extra_packages: [numpy=1.24, pandas=2.0, pyarrow=14.0.1]
partition: "ci1"
- os: ubuntu-latest
environment: mindeps
label: pandas
extra_packages: [numpy=1.21, pandas=2.0, pyarrow=7, pyarrow-hotfix]
extra_packages: [numpy=1.24, pandas=2.0, pyarrow=14.0.1]
partition: "not ci1"

- os: ubuntu-latest
Expand Down
28 changes: 3 additions & 25 deletions distributed/shuffle/_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def check_minimal_arrow_version() -> None:
Raises a ModuleNotFoundError if pyarrow is not installed or an
ImportError if the installed version is not recent enough.
"""
minversion = "7.0.0"
minversion = "14.0.1"
try:
import pyarrow as pa
except ModuleNotFoundError:
Expand All @@ -52,14 +52,7 @@ def check_minimal_arrow_version() -> None:
def concat_tables(tables: Iterable[pa.Table]) -> pa.Table:
import pyarrow as pa

if parse(pa.__version__) >= parse("14.0.0"):
return pa.concat_tables(tables, promote_options="permissive")
try:
return pa.concat_tables(tables, promote=True)
except pa.ArrowNotImplementedError as e:
if parse(pa.__version__) >= parse("12.0.0"):
raise e
raise
return pa.concat_tables(tables, promote_options="permissive")


def convert_shards(
Expand Down Expand Up @@ -179,23 +172,8 @@ def read_from_disk(path: Path) -> tuple[list[pa.Table], int]:
return shards, size


def concat_arrays(arrays: Iterable[pa.Array]) -> pa.Array:
import pyarrow as pa

try:
return pa.concat_arrays(arrays)
except pa.ArrowNotImplementedError as e:
if parse(pa.__version__) >= parse("12.0.0"):
raise
if e.args[0].startswith("concatenation of extension"):
raise RuntimeError(
"P2P shuffling requires pyarrow>=12.0.0 to support extension types."
) from e
raise


def _copy_table(table: pa.Table) -> pa.Table:
import pyarrow as pa

arrs = [concat_arrays(column.chunks) for column in table.columns]
arrs = [pa.concat_arrays(column.chunks) for column in table.columns]
return pa.table(data=arrs, schema=table.schema)
5 changes: 1 addition & 4 deletions distributed/shuffle/tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import warnings

import pytest
from packaging.version import parse as parse_version

np = pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")
Expand Down Expand Up @@ -37,8 +36,6 @@
from distributed.shuffle.tests.utils import AbstractShuffleTestPool
from distributed.utils_test import async_poll_for, gen_cluster, gen_test

NUMPY_GE_124 = parse_version(np.__version__) >= parse_version("1.24")


class ArrayRechunkTestPool(AbstractShuffleTestPool):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -175,7 +172,7 @@ async def test_lowlevel_rechunk(tmp_path, n_workers, barrier_first_worker, disk)
np.testing.assert_array_equal(
concatenate3(old_cs.tolist()),
concatenate3(all_chunks.tolist()),
**({"strict": True} if NUMPY_GE_124 else {}),
strict=True,
)


Expand Down
39 changes: 9 additions & 30 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from unittest import mock

import pytest
from packaging.version import parse as parse_version
from tornado.ioloop import IOLoop

import dask
Expand Down Expand Up @@ -73,13 +72,8 @@

try:
import pyarrow as pa

PYARROW_GE_12 = parse_version(pa.__version__).release >= (12,)
PYARROW_GE_14 = parse_version(pa.__version__).release >= (14,)
except ImportError:
pa = None
PYARROW_GE_12 = False
PYARROW_GE_14 = False


@pytest.fixture(params=[0, 0.3, 1], ids=["none", "some", "all"])
Expand Down Expand Up @@ -1145,6 +1139,9 @@ def __init__(self, value: int) -> None:
),
f"col{next(counter)}": pd.array(["x", "y"] * 50, dtype="category"),
f"col{next(counter)}": pd.array(["lorem ipsum"] * 100, dtype="string"),
# Extension types
f"col{next(counter)}": pd.period_range("2022-01-01", periods=100, freq="D"),
f"col{next(counter)}": pd.interval_range(start=0, end=100, freq=1),
# FIXME: PyArrow does not support sparse data:
# https://issues.apache.org/jira/browse/ARROW-8679
# f"col{next(counter)}": pd.array(
Expand All @@ -1158,17 +1155,6 @@ def __init__(self, value: int) -> None:
# ),
}

if PYARROW_GE_12:
columns.update(
{
# Extension types
f"col{next(counter)}": pd.period_range(
"2022-01-01", periods=100, freq="D"
),
f"col{next(counter)}": pd.interval_range(start=0, end=100, freq=1),
}
)

columns.update(
{
# PyArrow dtypes
Expand Down Expand Up @@ -2502,18 +2488,11 @@ def make_partition(i):
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
out = ddf.shuffle(on="a", ignore_index=True)

if PYARROW_GE_14:
result, expected = c.compute([ddf, out])
result = await result
expected = await expected
dd.assert_eq(result, expected)
del result
else:
with raises_with_cause(
RuntimeError, r"shuffling \w+ failed", pa.ArrowInvalid, "incompatible types"
):
await c.compute(out)
await c.close()
result, expected = c.compute([ddf, out])
result = await result
expected = await expected
dd.assert_eq(result, expected)
del result
del out

await assert_worker_cleanup(a)
Expand All @@ -2536,7 +2515,7 @@ def make_partition(i):
with raises_with_cause(
RuntimeError,
r"(shuffling \w*|shuffle_barrier) failed",
pa.ArrowTypeError if PYARROW_GE_14 else pa.ArrowInvalid,
pa.ArrowTypeError,
"incompatible types",
):
await c.compute(out)
Expand Down

0 comments on commit 5bbceb7

Please sign in to comment.