Skip to content

Commit

Permalink
support __getitem__ with single tuple of column names (#935)
Browse files Browse the repository at this point in the history
  • Loading branch information
raisadz authored Sep 10, 2024
1 parent 270adbd commit 359905b
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 7 deletions.
9 changes: 8 additions & 1 deletion narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ def __getitem__(self, item: str) -> ArrowSeries: ...
def __getitem__(self, item: slice) -> ArrowDataFrame: ...

def __getitem__(
self, item: str | slice | Sequence[int] | tuple[Sequence[int], str | int]
self,
item: str
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int],
) -> ArrowSeries | ArrowDataFrame:
if isinstance(item, str):
from narwhals._arrow.series import ArrowSeries
Expand Down Expand Up @@ -191,6 +196,8 @@ def __getitem__(
)

elif isinstance(item, Sequence) or (is_numpy_array(item) and item.ndim == 1):
if isinstance(item, Sequence) and all(isinstance(x, str) for x in item):
return self._from_native_frame(self._native_frame.select(item))
return self._from_native_frame(self._native_frame.take(item))

else: # pragma: no cover
Expand Down
17 changes: 14 additions & 3 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,22 @@ def __getitem__(self, item: tuple[Sequence[int], str | int]) -> PandasLikeSeries
def __getitem__(self, item: Sequence[int]) -> PandasLikeDataFrame: ...

@overload
def __getitem__(self, item: str) -> PandasLikeSeries: ...
def __getitem__(self, item: str) -> PandasLikeSeries: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[str]) -> PandasLikeDataFrame: ...

@overload
def __getitem__(self, item: slice) -> PandasLikeDataFrame: ...

def __getitem__(
self, item: str | slice | Sequence[int] | tuple[Sequence[int], str | int]
self,
item: str
| int
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int],
) -> PandasLikeSeries | PandasLikeDataFrame:
if isinstance(item, str):
from narwhals._pandas_like.series import PandasLikeSeries
Expand Down Expand Up @@ -174,7 +183,7 @@ def __getitem__(
from narwhals._pandas_like.series import PandasLikeSeries

if isinstance(item[1], str):
item = (item[0], self._native_frame.columns.get_loc(item[1]))
item = (item[0], self._native_frame.columns.get_loc(item[1])) # type: ignore[assignment]
native_series = self._native_frame.iloc[item]
elif isinstance(item[1], int):
native_series = self._native_frame.iloc[item]
Expand All @@ -191,6 +200,8 @@ def __getitem__(
elif isinstance(item, (slice, Sequence)) or (
is_numpy_array(item) and item.ndim == 1
):
if isinstance(item, Sequence) and all(isinstance(x, str) for x in item):
return self._from_native_frame(self._native_frame.loc[:, item])
return self._from_native_frame(self._native_frame.iloc[item])

else: # pragma: no cover
Expand Down
8 changes: 7 additions & 1 deletion narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,10 @@ def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: i
def __getitem__(self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self, item: str) -> Series: ...
def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[str]) -> Self: ...

@overload
def __getitem__(self, item: slice) -> Self: ...
Expand All @@ -622,6 +625,7 @@ def __getitem__(
item: str
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int]
| tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice],
) -> Series | Self:
Expand All @@ -644,6 +648,8 @@ def __getitem__(
`DataFrame`.
- `df[:, ['a', 'c']]` extracts all rows and columns `'a'` and `'c'` and returns a
`DataFrame`.
- `df[['a', 'c']]` extracts all rows and columns `'a'` and `'c'` and returns a
`DataFrame`.
- `df[0: 2, ['a', 'c']]` extracts the first two rows and columns `'a'` and `'c'` and
returns a `DataFrame`
- `df[:, 0: 2]` extracts all rows from the first two columns and returns a `DataFrame`
Expand Down
5 changes: 3 additions & 2 deletions narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: i

@overload
def __getitem__(self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self, item: str) -> Series: ...
def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: Sequence[str]) -> Self: ...

@overload
def __getitem__(self, item: slice) -> Self: ...
Expand Down
3 changes: 3 additions & 0 deletions tests/frame/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def test_slice_slice_columns(constructor_eager: Any) -> None:
result = df[:, [0, 2]]
expected = {"a": [1, 2, 3], "c": [7, 8, 9]}
compare_dicts(result, expected)
result = df[["b", "c"]]
expected = {"b": [4, 5, 6], "c": [7, 8, 9]}
compare_dicts(result, expected)


def test_slice_invalid(constructor_eager: Any) -> None:
Expand Down

0 comments on commit 359905b

Please sign in to comment.