From d6f3cd7f440013c3c5e699619139ecf3fffce07a Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:14:56 +0100 Subject: [PATCH] add support for PyArrow: multi-element `__getitem_`_ (#949) --- narwhals/_arrow/series.py | 2 ++ tests/series_only/slice_test.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 tests/series_only/slice_test.py diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 8765605fd..73390fdd3 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -321,6 +321,8 @@ def __getitem__(self, idx: slice | Sequence[int]) -> Self: ... def __getitem__(self, idx: int | slice | Sequence[int]) -> Any | Self: if isinstance(idx, int): return self._native_series[idx] + if isinstance(idx, Sequence): + return self._from_native_series(self._native_series.take(idx)) return self._from_native_series(self._native_series[idx]) def scatter(self, indices: int | Sequence[int], values: Any) -> Self: diff --git a/tests/series_only/slice_test.py b/tests/series_only/slice_test.py new file mode 100644 index 000000000..f9d2b4e2f --- /dev/null +++ b/tests/series_only/slice_test.py @@ -0,0 +1,15 @@ +from typing import Any + +import narwhals.stable.v1 as nw +from tests.utils import compare_dicts + + +def test_slice(constructor_eager: Any) -> None: + data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [1, 4, 2]} + df = nw.from_native(constructor_eager(data), eager_only=True) + result = {"a": df["a"][[0, 1]]} + expected = {"a": [1, 2]} + compare_dicts(result, expected) + result = {"a": df["a"][1:]} + expected = {"a": [2, 3]} + compare_dicts(result, expected)