diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index b3163fef6..c76593404 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -1,6 +1,5 @@ from __future__ import annotations -from copy import copy from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -12,6 +11,7 @@ from narwhals._dask.utils import binary_operation_returns_scalar from narwhals._dask.utils import maybe_evaluate from narwhals._dask.utils import narwhals_to_native_dtype +from narwhals._expression_parsing import infer_new_root_output_names from narwhals._pandas_like.utils import calculate_timestamp_date from narwhals._pandas_like.utils import calculate_timestamp_datetime from narwhals._pandas_like.utils import native_to_narwhals_dtype @@ -148,30 +148,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: results.append(result) return results - # Try tracking root and output names by combining them from all - # expressions appearing in args and kwargs. If any anonymous - # expression appears (e.g. nw.all()), then give up on tracking root names - # and just set it to None. - root_names = copy(self._root_names) - output_names = self._output_names - for arg in list(kwargs.values()): - if root_names is not None and isinstance(arg, self.__class__): - if arg._root_names is not None: - root_names.extend(arg._root_names) - else: - root_names = None - output_names = None - break - elif root_names is None: - output_names = None - break - - if not ( - (output_names is None and root_names is None) - or (output_names is not None and root_names is not None) - ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) + root_names, output_names = infer_new_root_output_names(self, **kwargs) return self.__class__( func, diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 4d51eb719..99bb3bb24 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -125,6 +125,38 @@ def parse_into_expr( raise InvalidIntoExprError.from_invalid_type(type(into_expr)) +def infer_new_root_output_names( + expr: CompliantExpr[Any], **kwargs: Any +) -> tuple[list[str] | None, list[str] | None]: + """Return new root and output names after chaining expressions. + + Try tracking root and output names by combining them from all expressions appearing in kwargs. + If any anonymous expression appears (e.g. nw.all()), then give up on tracking root names + and just set it to None. + """ + root_names = copy(expr._root_names) + output_names = expr._output_names + for arg in list(kwargs.values()): + if root_names is not None and isinstance(arg, expr.__class__): + if arg._root_names is not None: + root_names.extend(arg._root_names) + else: + root_names = None + output_names = None + break + elif root_names is None: + output_names = None + break + + if not ( + (output_names is None and root_names is None) + or (output_names is not None and root_names is not None) + ): # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + return root_names, output_names + + @overload def reuse_series_implementation( expr: PandasLikeExprT, @@ -201,30 +233,8 @@ def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]: raise AssertionError(msg) return out - # Try tracking root and output names by combining them from all - # expressions appearing in args and kwargs. If any anonymous - # expression appears (e.g. nw.all()), then give up on tracking root names - # and just set it to None. - root_names = copy(expr._root_names) - output_names = expr._output_names - for arg in list(kwargs.values()): - if root_names is not None and isinstance(arg, expr.__class__): - if arg._root_names is not None: - root_names.extend(arg._root_names) - else: - root_names = None - output_names = None - break - elif root_names is None: - output_names = None - break + root_names, output_names = infer_new_root_output_names(expr, **kwargs) - if not ( - (output_names is None and root_names is None) - or (output_names is not None and root_names is not None) - ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) return plx._create_expr_from_callable( # type: ignore[return-value] func, # type: ignore[arg-type] depth=expr._depth + 1, diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 1b98fcc46..b74aea678 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -1,11 +1,11 @@ from __future__ import annotations -from copy import copy from typing import TYPE_CHECKING from typing import Any from typing import Callable from typing import Sequence +from narwhals._expression_parsing import infer_new_root_output_names from narwhals._spark_like.utils import get_column_name from narwhals._spark_like.utils import maybe_evaluate from narwhals.typing import CompliantExpr @@ -106,30 +106,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: results.append(column_result) return results - # Try tracking root and output names by combining them from all - # expressions appearing in args and kwargs. If any anonymous - # expression appears (e.g. nw.all()), then give up on tracking root names - # and just set it to None. - root_names = copy(self._root_names) - output_names = self._output_names - for arg in list(kwargs.values()): - if root_names is not None and isinstance(arg, self.__class__): - if arg._root_names is not None: - root_names.extend(arg._root_names) - else: # pragma: no cover - root_names = None - output_names = None - break - elif root_names is None: - output_names = None - break - - if not ( - (output_names is None and root_names is None) - or (output_names is not None and root_names is not None) - ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) + root_names, output_names = infer_new_root_output_names(self, **kwargs) return self.__class__( func, diff --git a/pyproject.toml b/pyproject.toml index 0c2b4a9be..43a1dbc12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,6 +159,8 @@ filterwarnings = [ 'ignore:.*defaulting to pandas implementation', 'ignore:.*implementation has mismatches with pandas', 'ignore:.*You are using pyarrow version', + # This warning was temporarily raised by pandas but then reverted. + 'ignore:.*Passing a BlockManager to DataFrame:DeprecationWarning', ] xfail_strict = true markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]