Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Oct 25, 2024
1 parent c3f7135 commit 1496e49
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 36 deletions.
13 changes: 9 additions & 4 deletions narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals._pandas_like.utils import reset_index_no_copy
from narwhals.utils import Implementation
from narwhals.utils import remove_prefix

Expand Down Expand Up @@ -189,13 +188,17 @@ def agg_pandas( # noqa: PLR0915
result_simple_aggs = result_simple_aggs.rename(
columns=name_mapping, copy=False
)
reset_index_no_copy(result_simple_aggs, native_namespace)
# Keep inplace=True to avoid making a redundant copy.
# This one will likely stay in pandas https://github.com/pandas-dev/pandas/pull/51466/files
result_simple_aggs.reset_index(inplace=True) # noqa: PD002
if nunique_aggs:
result_nunique_aggs = grouped[list(nunique_aggs.values())].nunique(
dropna=False
)
result_nunique_aggs.columns = list(nunique_aggs.keys())
reset_index_no_copy(result_nunique_aggs, native_namespace)
# Keep inplace=True to avoid making a redundant copy.
# This one will likely stay in pandas https://github.com/pandas-dev/pandas/pull/51466/files
result_nunique_aggs.reset_index(inplace=True) # noqa: PD002
if simple_aggs and nunique_aggs:
if (
set(result_simple_aggs.columns)
Expand Down Expand Up @@ -261,6 +264,8 @@ def func(df: Any) -> Any:
else: # pragma: no cover
result_complex = grouped.apply(func)

reset_index_no_copy(result_complex, native_namespace)
# Keep inplace=True to avoid making a redundant copy.
# This one will likely stay in pandas https://github.com/pandas-dev/pandas/pull/51466/files
result_complex.reset_index(inplace=True) # noqa: PD002

return from_dataframe(result_complex.loc[:, output_names])
5 changes: 4 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,10 @@ def value_counts(
dropna=False,
sort=False,
normalize=normalize,
).reset_index()
)
# Keep inplace=True to avoid making a redundant copy.
# This one will likely stay in pandas https://github.com/pandas-dev/pandas/pull/51466/files
val_count.reset_index(inplace=True) # noqa: PD002

val_count.columns = [index_name_, value_name_]

Expand Down
19 changes: 0 additions & 19 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,22 +590,3 @@ def calculate_timestamp_date(s: pd.Series, time_unit: str) -> pd.Series:
else:
result = s * 1_000
return result


def reset_index_no_copy(native_dataframe: pd.DataFrame, native_namespace: Any) -> None:
    """
    Reset the index of `native_dataframe` in place, without triggering a copy.

    pandas' reset_index isn't free, and creates a copy. To avoid that, and keep
    overhead as low as possible, we introduce a utility to use instead of that.

    This should be used internally in Narwhals whenever you need to reset the index
    of an object which is not the original object passed by the user. We should
    never mutate the user's object, but if we create intermediate objects ourselves,
    then it's fine to mutate them.

    Every index level becomes a leading column, in level order (level 0 first),
    and the index is replaced by a fresh ``RangeIndex``. A level is inserted
    under its own name as-is; no default label is substituted for an unnamed
    level.

    Arguments:
        native_dataframe: The dataframe to modify in place.
        native_namespace: Namespace providing ``RangeIndex`` (e.g. ``pandas``).

    Raises:
        ValueError: If any index level name is already a column name.
    """
    index = native_dataframe.index
    columns = native_dataframe.columns
    conflicting = set(index.names).intersection(columns)
    if conflicting:
        # Report the level name(s) that actually clash; `index.name` would be
        # `None` for a multi-level index and hide the culprit.
        clashes = [name for name in index.names if name in conflicting]
        shown = clashes[0] if len(clashes) == 1 else clashes
        msg = f"Cannot insert column with name {shown} into dataframe with columns {columns}"
        raise ValueError(msg)
    for position, name in enumerate(index.names):
        # Insert level `position` at column `position` so that the level order
        # is preserved (inserting everything at 0 would reverse it).
        native_dataframe.insert(position, name, index.get_level_values(position))
    native_dataframe.index = native_namespace.RangeIndex(len(native_dataframe))
12 changes: 0 additions & 12 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from pandas.testing import assert_series_equal

import narwhals.stable.v1 as nw
from narwhals._pandas_like.utils import reset_index_no_copy
from tests.utils import PANDAS_VERSION
from tests.utils import get_module_version_as_tuple

Expand Down Expand Up @@ -148,14 +147,3 @@ def test_maybe_convert_dtypes_polars() -> None:
def test_get_trivial_version_with_uninstalled_module() -> None:
    """Looking up the version of a module that is not installed yields (0, 0, 0)."""
    assert get_module_version_as_tuple("non_existent_module") == (0, 0, 0)


def test_reset_index_no_copy() -> None:
    """Check the in-place index reset and its name-collision error."""
    data = {"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]}
    frame = pd.DataFrame(data).set_index("a")

    # The index level "a" must come back as a regular leading column.
    reset_index_no_copy(frame, pd)
    expected = pd.DataFrame(data)
    pd.testing.assert_frame_equal(frame, expected)

    # Naming the index after an existing column must be rejected.
    frame.index.name = "b"
    with pytest.raises(ValueError, match="Cannot insert column"):
        reset_index_no_copy(frame, pd)

0 comments on commit 1496e49

Please sign in to comment.