diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 7029cc0e2..e3ec3da79 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -9,6 +9,7 @@ from typing import overload from narwhals._arrow.utils import broadcast_series +from narwhals._arrow.utils import convert_slice_to_nparray from narwhals._arrow.utils import translate_dtype from narwhals._arrow.utils import validate_dataframe_comparand from narwhals._expression_parsing import evaluate_into_exprs @@ -126,7 +127,8 @@ def __getitem__( | slice | Sequence[int] | Sequence[str] - | tuple[Sequence[int], str | int], + | tuple[Sequence[int], str | int] + | tuple[slice, str | int], ) -> ArrowSeries | ArrowDataFrame: if isinstance(item, str): from narwhals._arrow.series import ArrowSeries @@ -144,7 +146,10 @@ def __getitem__( if item[0] == slice(None): selected_rows = self._native_frame else: - selected_rows = self._native_frame.take(item[0]) + range_ = convert_slice_to_nparray( + num_rows=len(self._native_frame), rows_slice=item[0] + ) + selected_rows = self._native_frame.take(range_) return self._from_native_frame(selected_rows.select(item[1])) @@ -174,13 +179,22 @@ def __getitem__( ) msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover raise TypeError(msg) # pragma: no cover - from narwhals._arrow.series import ArrowSeries # PyArrow columns are always strings col_name = item[1] if isinstance(item[1], str) else self.columns[item[1]] + assert not isinstance(item[0], str) # help mypy # noqa: S101 + if (isinstance(item[0], slice)) and (item[0] == slice(None)): + return ArrowSeries( + self._native_frame[col_name], + name=col_name, + backend_version=self._backend_version, + ) + range_ = convert_slice_to_nparray( + num_rows=len(self._native_frame), rows_slice=item[0] + ) return ArrowSeries( - self._native_frame[col_name].take(item[0]), + self._native_frame[col_name].take(range_), name=col_name, backend_version=self._backend_version, ) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 673584607..a2a45586b 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Sequence from narwhals import dtypes from narwhals.utils import isinstance_or_issubclass @@ -11,7 +12,7 @@ def translate_dtype(dtype: Any) -> dtypes.DType: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import if pa.types.is_int64(dtype): return dtypes.Int64() @@ -55,7 +56,7 @@ def translate_dtype(dtype: Any) -> dtypes.DType: def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import from narwhals import dtypes @@ -140,7 +141,7 @@ def validate_dataframe_comparand( return NotImplemented if isinstance(other, ArrowSeries): if len(other) == 1: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import value = other.item() if backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover @@ -157,7 +158,7 @@ def horizontal_concat(dfs: list[Any]) -> Any: Should be in namespace. """ - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import if not dfs: msg = "No dataframes to concatenate" # pragma: no cover @@ -190,7 +191,7 @@ def vertical_concat(dfs: list[Any]) -> Any: msg = "unable to vstack, column names don't match" raise TypeError(msg) - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import return pa.concat_tables(dfs).combine_chunks() @@ -198,8 +199,8 @@ def vertical_concat(dfs: list[Any]) -> Any: def floordiv_compat(left: Any, right: Any) -> Any: # The following lines are adapted from pandas' pyarrow implementation. # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154 - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import if isinstance(left, (int, float)): left = pa.scalar(left) @@ -237,8 +238,8 @@ def floordiv_compat(left: Any, right: Any) -> Any: def cast_for_truediv(arrow_array: Any, pa_object: Any) -> tuple[Any, Any]: # Lifted from: # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122 - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import # Ensure int / int -> float mirroring Python/Numpy behavior # as pc.divide_checked(int, int) -> int @@ -260,7 +261,7 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]: if fast_path: return [s._native_series for s in series] - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import reshaped = [] for s, length in zip(series, lengths): @@ -274,3 +275,14 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]: reshaped.append(s_native) return reshaped + + +def convert_slice_to_nparray( + num_rows: int, rows_slice: slice | int | Sequence[int] +) -> Any: + import numpy as np # ignore-banned-import + + if isinstance(rows_slice, slice): + return np.arange(num_rows)[rows_slice] + else: + return rows_slice diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index ce5999aa2..b51d53baa 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -602,11 +602,15 @@ def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ... @overload def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap] @overload + def __getitem__(self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap] + @overload def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ... @overload def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ... @overload def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap] + @overload + def __getitem__(self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap] @overload def __getitem__(self, item: Sequence[int]) -> Self: ... @@ -627,6 +631,7 @@ def __getitem__( | Sequence[int] | Sequence[str] | tuple[Sequence[int], str | int] + | tuple[slice, str | int] | tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice], ) -> Series | Self: """ diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 862ba5d1a..fa98fd96f 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -78,21 +78,25 @@ def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ... def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ... @overload def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ... - @overload def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap] @overload + def __getitem__(self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap] + @overload def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ... @overload def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ... - @overload def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap] + @overload + def __getitem__(self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap] @overload def __getitem__(self, item: Sequence[int]) -> Self: ... + @overload def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap] + @overload def __getitem__(self, item: Sequence[str]) -> Self: ... diff --git a/tests/frame/slice_test.py b/tests/frame/slice_test.py index af0e57f20..0867844f9 100644 --- a/tests/frame/slice_test.py +++ b/tests/frame/slice_test.py @@ -149,6 +149,9 @@ def test_slice_slice_columns(constructor_eager: Any) -> None: result = df[:, [0, 2]] expected = {"a": [1, 2, 3], "c": [7, 8, 9]} compare_dicts(result, expected) + result = df[:2, [0, 2]] + expected = {"a": [1, 2], "c": [7, 8]} + compare_dicts(result, expected) result = df[["b", "c"]] expected = {"b": [4, 5, 6], "c": [7, 8, 9]} compare_dicts(result, expected) diff --git a/tests/series_only/slice_test.py b/tests/series_only/slice_test.py index f9d2b4e2f..48cf15bc7 100644 --- a/tests/series_only/slice_test.py +++ b/tests/series_only/slice_test.py @@ -13,3 +13,15 @@ def test_slice(constructor_eager: Any) -> None: result = {"a": df["a"][1:]} expected = {"a": [2, 3]} compare_dicts(result, expected) + result = {"b": df[:, 1]} + expected = {"b": [4, 5, 6]} + compare_dicts(result, expected) + result = {"b": df[:, "b"]} + expected = {"b": [4, 5, 6]} + compare_dicts(result, expected) + result = {"b": df[:2, "b"]} + expected = {"b": [4, 5]} + compare_dicts(result, expected) + result = {"b": df[:2, 1]} + expected = {"b": [4, 5]} + compare_dicts(result, expected)