Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug: add conversion from slice to array for selecting rows in pyarrow __getitem__ #978

Merged
merged 2 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import overload

from narwhals._arrow.utils import broadcast_series
from narwhals._arrow.utils import convert_slice_to_nparray
from narwhals._arrow.utils import translate_dtype
from narwhals._arrow.utils import validate_dataframe_comparand
from narwhals._expression_parsing import evaluate_into_exprs
Expand Down Expand Up @@ -126,7 +127,8 @@ def __getitem__(
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int],
| tuple[Sequence[int], str | int]
| tuple[slice, str | int],
) -> ArrowSeries | ArrowDataFrame:
if isinstance(item, str):
from narwhals._arrow.series import ArrowSeries
Expand All @@ -144,7 +146,10 @@ def __getitem__(
if item[0] == slice(None):
selected_rows = self._native_frame
else:
selected_rows = self._native_frame.take(item[0])
range_ = convert_slice_to_nparray(
num_rows=len(self._native_frame), rows_slice=item[0]
)
selected_rows = self._native_frame.take(range_)

return self._from_native_frame(selected_rows.select(item[1]))

Expand Down Expand Up @@ -174,13 +179,22 @@ def __getitem__(
)
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg) # pragma: no cover

from narwhals._arrow.series import ArrowSeries

# PyArrow columns are always strings
col_name = item[1] if isinstance(item[1], str) else self.columns[item[1]]
assert not isinstance(item[0], str) # help mypy # noqa: S101
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fancy changing this from assert to a raise statement as a follow up?

the dangers of assert in python

TL;DR: `assert` statements only run in debug mode; if you run Python with the optimization flag (`-O`), they are skipped.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comment @FBruzzesi , I replaced assert with raise in this PR #980

Please note that the line never gets hit, so `# pragma: no cover` is added.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, thanks for pointing that out 👌! Let's move the conversation there anyway

if (isinstance(item[0], slice)) and (item[0] == slice(None)):
return ArrowSeries(
self._native_frame[col_name],
name=col_name,
backend_version=self._backend_version,
)
range_ = convert_slice_to_nparray(
num_rows=len(self._native_frame), rows_slice=item[0]
)
return ArrowSeries(
self._native_frame[col_name].take(item[0]),
self._native_frame[col_name].take(range_),
name=col_name,
backend_version=self._backend_version,
)
Expand Down
32 changes: 22 additions & 10 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence

from narwhals import dtypes
from narwhals.utils import isinstance_or_issubclass
Expand All @@ -11,7 +12,7 @@


def translate_dtype(dtype: Any) -> dtypes.DType:
import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

if pa.types.is_int64(dtype):
return dtypes.Int64()
Expand Down Expand Up @@ -55,7 +56,7 @@ def translate_dtype(dtype: Any) -> dtypes.DType:


def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

from narwhals import dtypes

Expand Down Expand Up @@ -140,7 +141,7 @@ def validate_dataframe_comparand(
return NotImplemented
if isinstance(other, ArrowSeries):
if len(other) == 1:
import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

value = other.item()
if backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover
Expand All @@ -157,7 +158,7 @@ def horizontal_concat(dfs: list[Any]) -> Any:

Should be in namespace.
"""
import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

if not dfs:
msg = "No dataframes to concatenate" # pragma: no cover
Expand Down Expand Up @@ -190,16 +191,16 @@ def vertical_concat(dfs: list[Any]) -> Any:
msg = "unable to vstack, column names don't match"
raise TypeError(msg)

import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

return pa.concat_tables(dfs).combine_chunks()


def floordiv_compat(left: Any, right: Any) -> Any:
# The following lines are adapted from pandas' pyarrow implementation.
# Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

if isinstance(left, (int, float)):
left = pa.scalar(left)
Expand Down Expand Up @@ -237,8 +238,8 @@ def floordiv_compat(left: Any, right: Any) -> Any:
def cast_for_truediv(arrow_array: Any, pa_object: Any) -> tuple[Any, Any]:
# Lifted from:
# https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

# Ensure int / int -> float mirroring Python/Numpy behavior
# as pc.divide_checked(int, int) -> int
Expand All @@ -260,7 +261,7 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]:
if fast_path:
return [s._native_series for s in series]

import pyarrow as pa # ignore-banned-import()
import pyarrow as pa # ignore-banned-import

reshaped = []
for s, length in zip(series, lengths):
Expand All @@ -274,3 +275,14 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]:
reshaped.append(s_native)

return reshaped


def convert_slice_to_nparray(
    num_rows: int, rows_slice: slice | int | Sequence[int]
) -> Any:
    """Expand a row slice into an explicit array of row positions.

    Arguments:
        num_rows: Number of rows in the object being indexed; used to
            resolve open-ended, negative, or out-of-range slice bounds.
        rows_slice: The row selector. Only ``slice`` objects are converted;
            any other value is passed through untouched.

    Returns:
        A NumPy integer array holding the positions selected by the slice
        when ``rows_slice`` is a ``slice``; otherwise ``rows_slice`` itself.

    Raises:
        ValueError: If the slice has a step of zero.
    """
    import numpy as np  # ignore-banned-import

    if isinstance(rows_slice, slice):
        # `slice.indices` normalises start/stop/step against `num_rows`, so
        # only the selected positions are materialised rather than building
        # a full `arange(num_rows)` and slicing it afterwards.
        return np.arange(*rows_slice.indices(num_rows))
    return rows_slice
5 changes: 5 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,11 +602,15 @@ def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[int]) -> Self: ...
Expand All @@ -627,6 +631,7 @@ def __getitem__(
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int]
| tuple[slice, str | int]
| tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice],
) -> Series | Self:
"""
Expand Down
8 changes: 6 additions & 2 deletions narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,25 @@ def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ...
def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ...

@overload
def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ...

@overload
def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[str]) -> Self: ...

Expand Down
3 changes: 3 additions & 0 deletions tests/frame/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ def test_slice_slice_columns(constructor_eager: Any) -> None:
result = df[:, [0, 2]]
expected = {"a": [1, 2, 3], "c": [7, 8, 9]}
compare_dicts(result, expected)
result = df[:2, [0, 2]]
expected = {"a": [1, 2], "c": [7, 8]}
compare_dicts(result, expected)
result = df[["b", "c"]]
expected = {"b": [4, 5, 6], "c": [7, 8, 9]}
compare_dicts(result, expected)
Expand Down
12 changes: 12 additions & 0 deletions tests/series_only/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,15 @@ def test_slice(constructor_eager: Any) -> None:
result = {"a": df["a"][1:]}
expected = {"a": [2, 3]}
compare_dicts(result, expected)
result = {"b": df[:, 1]}
expected = {"b": [4, 5, 6]}
compare_dicts(result, expected)
result = {"b": df[:, "b"]}
expected = {"b": [4, 5, 6]}
compare_dicts(result, expected)
result = {"b": df[:2, "b"]}
expected = {"b": [4, 5]}
compare_dicts(result, expected)
result = {"b": df[:2, 1]}
expected = {"b": [4, 5]}
compare_dicts(result, expected)
Loading