From 9dfd6f5455648d32e9e2eff37764c3ac1d1715b1 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Mon, 26 Aug 2024 19:37:59 +0200 Subject: [PATCH] patch: stableify `concat` function (#869) * patch: stableify concat function * @overload * test `to_lazy` * supposed to raise due to type mismatch --- narwhals/stable/v1.py | 25 ++++++++++++++++++++++++- tests/frame/concat_test.py | 6 +++--- tests/frame/test_invalid.py | 2 +- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index d166a814f..1b64b78f1 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -37,7 +37,6 @@ from narwhals.expr import Then as NwThen from narwhals.expr import When as NwWhen from narwhals.expr import when as nw_when -from narwhals.functions import concat from narwhals.functions import show_versions from narwhals.schema import Schema as NwSchema from narwhals.series import Series as NwSeries @@ -1338,6 +1337,30 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return _stableify(nw.mean_horizontal(*exprs)) +@overload +def concat( + items: Iterable[DataFrame[Any]], + *, + how: Literal["horizontal", "vertical"] = "vertical", +) -> DataFrame[Any]: ... + + +@overload +def concat( + items: Iterable[LazyFrame[Any]], + *, + how: Literal["horizontal", "vertical"] = "vertical", +) -> LazyFrame[Any]: ... + + +def concat( + items: Iterable[DataFrame[Any] | LazyFrame[Any]], + *, + how: Literal["horizontal", "vertical"] = "vertical", +) -> DataFrame[Any] | LazyFrame[Any]: + return _stableify(nw.concat(items, how=how)) # type: ignore[no-any-return] + + def is_ordered_categorical(series: Series) -> bool: """ Return whether indices of categories are semantically meaningful. diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 970220bf2..44b0f6e1a 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -10,10 +10,10 @@ def test_concat_horizontal(constructor: Any, request: Any) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df_left = nw.from_native(constructor(data)) + df_left = nw.from_native(constructor(data)).lazy() data_right = {"c": [6, 12, -1], "d": [0, -4, 2]} - df_right = nw.from_native(constructor(data_right)) + df_right = nw.from_native(constructor(data_right)).lazy() result = nw.concat([df_left, df_right], how="horizontal") expected = { @@ -34,7 +34,7 @@ def test_concat_vertical(constructor: Any, request: Any) -> None: request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_left = ( - nw.from_native(constructor(data)).rename({"a": "c", "b": "d"}).drop("z").lazy() + nw.from_native(constructor(data)).lazy().rename({"a": "c", "b": "d"}).drop("z") ) data_right = {"c": [6, 12, -1], "d": [0, -4, 2]} diff --git a/tests/frame/test_invalid.py b/tests/frame/test_invalid.py index cf1fff6d1..b8bca586f 100644 --- a/tests/frame/test_invalid.py +++ b/tests/frame/test_invalid.py @@ -24,7 +24,7 @@ def test_validate_laziness() -> None: NotImplementedError, match=("The items to concatenate should either all be eager, or all lazy"), ): - nw.concat([nw.from_native(df, eager_only=True), nw.from_native(df).lazy()]) + nw.concat([nw.from_native(df, eager_only=True), nw.from_native(df).lazy()]) # type: ignore[list-item] @pytest.mark.skipif(