diff --git a/dask_geopandas/backends.py b/dask_geopandas/backends.py index b73559d..60408a1 100644 --- a/dask_geopandas/backends.py +++ b/dask_geopandas/backends.py @@ -1,6 +1,7 @@ import uuid from packaging.version import Version +import dask from dask import config # Check if dask-dataframe is using dask-expr (default of None means True as well) @@ -84,3 +85,36 @@ def get_pyarrow_schema_geopandas(obj): for col in obj.columns[obj.dtypes == "geometry"]: df[col] = obj[col].to_wkb() return pa.Schema.from_pandas(df) + + +if Version(dask.__version__) >= Version("2023.6.1"): + from dask.dataframe.dispatch import ( + from_pyarrow_table_dispatch, + to_pyarrow_table_dispatch, + ) + + @to_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,)) + def get_pyarrow_table_from_geopandas(obj, **kwargs): + # `kwargs` must be supported by `pyarrow.Table.from_pandas` + import pyarrow as pa + + if Version(geopandas.__version__).major < 1: + return pa.Table.from_pandas(obj.to_wkb(), **kwargs) + else: + # TODO handle kwargs? + return pa.table(obj.to_arrow()) + + @from_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,)) + def get_geopandas_geodataframe_from_pyarrow(meta, table, **kwargs): + # `kwargs` must be supported by `pyarrow.Table.to_pandas` + if Version(geopandas.__version__).major < 1: + df = table.to_pandas(**kwargs) + + for col in meta.columns[meta.dtypes == "geometry"]: + df[col] = geopandas.GeoSeries.from_wkb(df[col], crs=meta[col].crs) + + return df + + else: + # TODO handle kwargs? + return geopandas.GeoDataFrame.from_arrow(table) diff --git a/dask_geopandas/tests/test_distributed.py b/dask_geopandas/tests/test_distributed.py new file mode 100644 index 0000000..9222df3 --- /dev/null +++ b/dask_geopandas/tests/test_distributed.py @@ -0,0 +1,38 @@ +from packaging.version import Version + +import geopandas + +import dask_geopandas + +import pytest +from geopandas.testing import assert_geodataframe_equal + +distributed = pytest.importorskip("distributed") + + +from distributed import Client, LocalCluster + + +@pytest.mark.skipif( + Version(distributed.__version__) < Version("2024.6.0"), + reason="distributed < 2024.6 has a wrong assertion", + # https://github.com/dask/distributed/pull/8667 +) +@pytest.mark.skipif( + Version(distributed.__version__) < Version("0.13"), + reason="geopandas < 0.13 does not implement sorting geometries", +) +def test_spatial_shuffle(naturalearth_cities): + df_points = geopandas.read_file(naturalearth_cities) + + with LocalCluster(n_workers=1) as cluster: + with Client(cluster): + ddf_points = dask_geopandas.from_geopandas(df_points, npartitions=4) + + ddf_result = ddf_points.spatial_shuffle( + by="hilbert", calculate_partitions=False + ) + result = ddf_result.compute() + + expected = df_points.sort_values("geometry").reset_index(drop=True) + assert_geodataframe_equal(result.reset_index(drop=True), expected)