Skip to content

Commit

Permalink
Bug: add conversion from slice to array for selecting rows in pyarrow…
Browse files Browse the repository at this point in the history
… `__getitem__` (#978)
  • Loading branch information
raisadz authored Sep 15, 2024
1 parent 77f505a commit 60ed72e
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 16 deletions.
22 changes: 18 additions & 4 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import overload

from narwhals._arrow.utils import broadcast_series
from narwhals._arrow.utils import convert_slice_to_nparray
from narwhals._arrow.utils import translate_dtype
from narwhals._arrow.utils import validate_dataframe_comparand
from narwhals._expression_parsing import evaluate_into_exprs
Expand Down Expand Up @@ -126,7 +127,8 @@ def __getitem__(
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int],
| tuple[Sequence[int], str | int]
| tuple[slice, str | int],
) -> ArrowSeries | ArrowDataFrame:
if isinstance(item, str):
from narwhals._arrow.series import ArrowSeries
Expand All @@ -144,7 +146,10 @@ def __getitem__(
if item[0] == slice(None):
selected_rows = self._native_frame
else:
selected_rows = self._native_frame.take(item[0])
range_ = convert_slice_to_nparray(
num_rows=len(self._native_frame), rows_slice=item[0]
)
selected_rows = self._native_frame.take(range_)

return self._from_native_frame(selected_rows.select(item[1]))

Expand Down Expand Up @@ -174,13 +179,22 @@ def __getitem__(
)
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg) # pragma: no cover

from narwhals._arrow.series import ArrowSeries

# PyArrow columns are always strings
col_name = item[1] if isinstance(item[1], str) else self.columns[item[1]]
assert not isinstance(item[0], str) # help mypy # noqa: S101
if (isinstance(item[0], slice)) and (item[0] == slice(None)):
return ArrowSeries(
self._native_frame[col_name],
name=col_name,
backend_version=self._backend_version,
)
range_ = convert_slice_to_nparray(
num_rows=len(self._native_frame), rows_slice=item[0]
)
return ArrowSeries(
self._native_frame[col_name].take(item[0]),
self._native_frame[col_name].take(range_),
name=col_name,
backend_version=self._backend_version,
)
Expand Down
32 changes: 22 additions & 10 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence

from narwhals import dtypes
from narwhals.utils import isinstance_or_issubclass
Expand All @@ -11,7 +12,7 @@


def translate_dtype(dtype: Any) -> dtypes.DType:
import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

if pa.types.is_int64(dtype):
return dtypes.Int64()
Expand Down Expand Up @@ -55,7 +56,7 @@ def translate_dtype(dtype: Any) -> dtypes.DType:


def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

from narwhals import dtypes

Expand Down Expand Up @@ -140,7 +141,7 @@ def validate_dataframe_comparand(
return NotImplemented
if isinstance(other, ArrowSeries):
if len(other) == 1:
import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

value = other.item()
if backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover
Expand All @@ -157,7 +158,7 @@ def horizontal_concat(dfs: list[Any]) -> Any:
Should be in namespace.
"""
import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

if not dfs:
msg = "No dataframes to concatenate" # pragma: no cover
Expand Down Expand Up @@ -190,16 +191,16 @@ def vertical_concat(dfs: list[Any]) -> Any:
msg = "unable to vstack, column names don't match"
raise TypeError(msg)

import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

return pa.concat_tables(dfs).combine_chunks()


def floordiv_compat(left: Any, right: Any) -> Any:
# The following lines are adapted from pandas' pyarrow implementation.
# Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

if isinstance(left, (int, float)):
left = pa.scalar(left)
Expand Down Expand Up @@ -237,8 +238,8 @@ def floordiv_compat(left: Any, right: Any) -> Any:
def cast_for_truediv(arrow_array: Any, pa_object: Any) -> tuple[Any, Any]:
# Lifted from:
# https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

# Ensure int / int -> float mirroring Python/Numpy behavior
# as pc.divide_checked(int, int) -> int
Expand All @@ -260,7 +261,7 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]:
if fast_path:
return [s._native_series for s in series]

import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

reshaped = []
for s, length in zip(series, lengths):
Expand All @@ -274,3 +275,14 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]:
reshaped.append(s_native)

return reshaped


def convert_slice_to_nparray(
num_rows: int, rows_slice: slice | int | Sequence[int]
) -> Any:
import numpy as np # ignore-banned-import

if isinstance(rows_slice, slice):
return np.arange(num_rows)[rows_slice]
else:
return rows_slice
5 changes: 5 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,11 +602,15 @@ def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[int]) -> Self: ...
Expand All @@ -627,6 +631,7 @@ def __getitem__(
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int]
| tuple[slice, str | int]
| tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice],
) -> Series | Self:
"""
Expand Down
8 changes: 6 additions & 2 deletions narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,25 @@ def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ...
def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ...

@overload
def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ...

@overload
def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[str]) -> Self: ...

Expand Down
3 changes: 3 additions & 0 deletions tests/frame/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ def test_slice_slice_columns(constructor_eager: Any) -> None:
result = df[:, [0, 2]]
expected = {"a": [1, 2, 3], "c": [7, 8, 9]}
compare_dicts(result, expected)
result = df[:2, [0, 2]]
expected = {"a": [1, 2], "c": [7, 8]}
compare_dicts(result, expected)
result = df[["b", "c"]]
expected = {"b": [4, 5, 6], "c": [7, 8, 9]}
compare_dicts(result, expected)
Expand Down
12 changes: 12 additions & 0 deletions tests/series_only/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,15 @@ def test_slice(constructor_eager: Any) -> None:
result = {"a": df["a"][1:]}
expected = {"a": [2, 3]}
compare_dicts(result, expected)
result = {"b": df[:, 1]}
expected = {"b": [4, 5, 6]}
compare_dicts(result, expected)
result = {"b": df[:, "b"]}
expected = {"b": [4, 5, 6]}
compare_dicts(result, expected)
result = {"b": df[:2, "b"]}
expected = {"b": [4, 5]}
compare_dicts(result, expected)
result = {"b": df[:2, 1]}
expected = {"b": [4, 5]}
compare_dicts(result, expected)

0 comments on commit 60ed72e

Please sign in to comment.