diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index fa5a69950..f409ef735 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -121,7 +121,12 @@ def __getitem__(self, item: str) -> ArrowSeries: ... def __getitem__(self, item: slice) -> ArrowDataFrame: ... def __getitem__( - self, item: str | slice | Sequence[int] | tuple[Sequence[int], str | int] + self, + item: str + | slice + | Sequence[int] + | Sequence[str] + | tuple[Sequence[int], str | int], ) -> ArrowSeries | ArrowDataFrame: if isinstance(item, str): from narwhals._arrow.series import ArrowSeries @@ -191,6 +196,8 @@ def __getitem__( ) elif isinstance(item, Sequence) or (is_numpy_array(item) and item.ndim == 1): + if isinstance(item, Sequence) and all(isinstance(x, str) for x in item): + return self._from_native_frame(self._native_frame.select(item)) return self._from_native_frame(self._native_frame.take(item)) else: # pragma: no cover diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 499777833..71a659998 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -111,13 +111,22 @@ def __getitem__(self, item: tuple[Sequence[int], str | int]) -> PandasLikeSeries def __getitem__(self, item: Sequence[int]) -> PandasLikeDataFrame: ... @overload - def __getitem__(self, item: str) -> PandasLikeSeries: ... + def __getitem__(self, item: str) -> PandasLikeSeries: ... # type: ignore[overload-overlap] + + @overload + def __getitem__(self, item: Sequence[str]) -> PandasLikeDataFrame: ... @overload def __getitem__(self, item: slice) -> PandasLikeDataFrame: ... def __getitem__( - self, item: str | slice | Sequence[int] | tuple[Sequence[int], str | int] + self, + item: str + | int + | slice + | Sequence[int] + | Sequence[str] + | tuple[Sequence[int], str | int], ) -> PandasLikeSeries | PandasLikeDataFrame: if isinstance(item, str): from narwhals._pandas_like.series import PandasLikeSeries @@ -174,7 +183,7 @@ def __getitem__( from narwhals._pandas_like.series import PandasLikeSeries if isinstance(item[1], str): - item = (item[0], self._native_frame.columns.get_loc(item[1])) + item = (item[0], self._native_frame.columns.get_loc(item[1])) # type: ignore[assignment] native_series = self._native_frame.iloc[item] elif isinstance(item[1], int): native_series = self._native_frame.iloc[item] @@ -191,6 +200,8 @@ def __getitem__( elif isinstance(item, (slice, Sequence)) or ( is_numpy_array(item) and item.ndim == 1 ): + if isinstance(item, Sequence) and all(isinstance(x, str) for x in item): + return self._from_native_frame(self._native_frame.loc[:, item]) return self._from_native_frame(self._native_frame.iloc[item]) else: # pragma: no cover diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index a266b73c7..1b91f0910 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -612,7 +612,10 @@ def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: i def __getitem__(self, item: Sequence[int]) -> Self: ... @overload - def __getitem__(self, item: str) -> Series: ... + def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap] + + @overload + def __getitem__(self, item: Sequence[str]) -> Self: ... @overload def __getitem__(self, item: slice) -> Self: ... @@ -622,6 +625,7 @@ def __getitem__( item: str | slice | Sequence[int] + | Sequence[str] | tuple[Sequence[int], str | int] | tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice], ) -> Series | Self: @@ -644,6 +648,8 @@ def __getitem__( `DataFrame`. - `df[:, ['a', 'c']]` extracts all rows and columns `'a'` and `'c'` and returns a `DataFrame`. + - `df[['a', 'c']]` extracts all rows and columns `'a'` and `'c'` and returns a + `DataFrame`. - `df[0: 2, ['a', 'c']]` extracts the first two rows and columns `'a'` and `'c'` and returns a `DataFrame` - `df[:, 0: 2]` extracts all rows from the first two columns and returns a `DataFrame` diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 78cfa5ba1..862ba5d1a 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -91,9 +91,10 @@ def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: i @overload def __getitem__(self, item: Sequence[int]) -> Self: ... - @overload - def __getitem__(self, item: str) -> Series: ... + def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap] + @overload + def __getitem__(self, item: Sequence[str]) -> Self: ... @overload def __getitem__(self, item: slice) -> Self: ... diff --git a/tests/frame/slice_test.py b/tests/frame/slice_test.py index 18b05bf3b..834e88bff 100644 --- a/tests/frame/slice_test.py +++ b/tests/frame/slice_test.py @@ -147,6 +147,9 @@ def test_slice_slice_columns(constructor_eager: Any) -> None: result = df[:, [0, 2]] expected = {"a": [1, 2, 3], "c": [7, 8, 9]} compare_dicts(result, expected) + result = df[["b", "c"]] + expected = {"b": [4, 5, 6], "c": [7, 8, 9]} + compare_dicts(result, expected) def test_slice_invalid(constructor_eager: Any) -> None: