Skip to content

Commit

Permalink
support __getitem__ with single tuple of column names
Browse files Browse the repository at this point in the history
  • Loading branch information
raisadz committed Sep 9, 2024
1 parent e7b3b83 commit 6653789
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 7 deletions.
9 changes: 8 additions & 1 deletion narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ def __getitem__(self, item: str) -> ArrowSeries: ...
def __getitem__(self, item: slice) -> ArrowDataFrame: ...

def __getitem__(
self, item: str | slice | Sequence[int] | tuple[Sequence[int], str | int]
self,
item: str
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int],
) -> ArrowSeries | ArrowDataFrame:
if isinstance(item, str):
from narwhals._arrow.series import ArrowSeries
Expand Down Expand Up @@ -191,6 +196,8 @@ def __getitem__(
)

elif isinstance(item, Sequence) or (is_numpy_array(item) and item.ndim == 1):
if isinstance(item, Sequence) and all(isinstance(x, str) for x in item):
return self._from_native_frame(self._native_frame.select(item))
return self._from_native_frame(self._native_frame.take(item))

else: # pragma: no cover
Expand Down
19 changes: 16 additions & 3 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,30 @@ def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
return self.to_numpy(dtype=dtype, copy=copy)

@overload
def __getitem__(self, item: tuple[Sequence[int], str | int]) -> PandasLikeSeries: ... # type: ignore[overload-overlap]
def __getitem__( # type: ignore[overload-overlap]
self, item: tuple[Sequence[int], Sequence[str], str | int]
) -> PandasLikeSeries: ...

@overload
def __getitem__(self, item: Sequence[int]) -> PandasLikeDataFrame: ...

@overload
def __getitem__(self, item: str) -> PandasLikeSeries: ...
def __getitem__(self, item: str) -> PandasLikeSeries: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[str]) -> PandasLikeDataFrame: ...

@overload
def __getitem__(self, item: slice) -> PandasLikeDataFrame: ...

def __getitem__(
self, item: str | slice | Sequence[int] | tuple[Sequence[int], str | int]
self,
item: str
| int
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], Sequence[str], str | int],
) -> PandasLikeSeries | PandasLikeDataFrame:
if isinstance(item, str):
from narwhals._pandas_like.series import PandasLikeSeries
Expand Down Expand Up @@ -191,6 +202,8 @@ def __getitem__(
elif isinstance(item, (slice, Sequence)) or (
is_numpy_array(item) and item.ndim == 1
):
if isinstance(item, Sequence) and all(isinstance(x, str) for x in item):
return self._from_native_frame(self._native_frame[item])
return self._from_native_frame(self._native_frame.iloc[item])

else: # pragma: no cover
Expand Down
8 changes: 7 additions & 1 deletion narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,10 @@ def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: i
def __getitem__(self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self, item: str) -> Series: ...
def __getitem__(self, item: Sequence[str]) -> Self: ...

@overload
def __getitem__(self, item: str) -> Series: ... # type: ignore[misc]

@overload
def __getitem__(self, item: slice) -> Self: ...
Expand All @@ -620,6 +623,7 @@ def __getitem__(
item: str
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int]
| tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice],
) -> Series | Self:
Expand All @@ -642,6 +646,8 @@ def __getitem__(
`DataFrame`.
- `df[:, ['a', 'c']]` extracts all rows and columns `'a'` and `'c'` and returns a
`DataFrame`.
- `df[['a', 'c']]` extracts all rows and columns `'a'` and `'c'` and returns a
`DataFrame`.
- `df[0: 2, ['a', 'c']]` extracts the first two rows and columns `'a'` and `'c'` and
returns a `DataFrame`
- `df[:, 0: 2]` extracts all rows from the first two columns and returns a `DataFrame`
Expand Down
5 changes: 3 additions & 2 deletions narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: i

@overload
def __getitem__(self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self, item: str) -> Series: ...
def __getitem__(self, item: Sequence[str]) -> Self: ...
@overload
def __getitem__(self, item: str) -> Series: ... # type: ignore[misc]

@overload
def __getitem__(self, item: slice) -> Self: ...
Expand Down
3 changes: 3 additions & 0 deletions tests/frame/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def test_slice_slice_columns(constructor_eager: Any) -> None:
result = df[:, [0, 2]]
expected = {"a": [1, 2, 3], "c": [7, 8, 9]}
compare_dicts(result, expected)
result = df[["b", "c"]]
expected = {"b": [4, 5, 6], "c": [7, 8, 9]}
compare_dicts(result, expected)


def test_slice_invalid(constructor_eager: Any) -> None:
Expand Down

0 comments on commit 6653789

Please sign in to comment.