From 36dacf91886d67333f6127ebb70cd2f5bdeeeea4 Mon Sep 17 00:00:00 2001
From: Marco Edward Gorelli
Date: Thu, 9 Jan 2025 09:22:13 +0000
Subject: [PATCH] feat: implement `n_unique` for DuckDB (#1762)

---
 narwhals/_duckdb/expr.py               | 22 ++++++++++++++++++++++
 tests/expr_and_series/n_unique_test.py |  6 +-----
 tests/expr_and_series/unary_test.py    |  6 +-----
 tests/group_by_test.py                 |  8 +-------
 tpch/execute.py                        |  2 +-
 5 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py
index 4515cbba1..e5e612085 100644
--- a/narwhals/_duckdb/expr.py
+++ b/narwhals/_duckdb/expr.py
@@ -408,6 +408,28 @@ def sum(self) -> Self:
             lambda _input: FunctionExpression("sum", _input), "sum", returns_scalar=True
         )
 
+    def n_unique(self) -> Self:
+        from duckdb import CaseExpression
+        from duckdb import ConstantExpression
+        from duckdb import FunctionExpression
+
+        def func(_input: duckdb.Expression) -> duckdb.Expression:
+            # https://stackoverflow.com/a/79338887/4451315
+            return FunctionExpression(
+                "array_unique", FunctionExpression("array_agg", _input)
+            ) + FunctionExpression(
+                "max",
+                CaseExpression(
+                    condition=_input.isnotnull(), value=ConstantExpression(0)
+                ).otherwise(ConstantExpression(1)),
+            )
+
+        return self._from_call(
+            func,
+            "n_unique",
+            returns_scalar=True,
+        )
+
     def count(self) -> Self:
         from duckdb import FunctionExpression
 
diff --git a/tests/expr_and_series/n_unique_test.py b/tests/expr_and_series/n_unique_test.py
index d8e4d9b77..90bffb04b 100644
--- a/tests/expr_and_series/n_unique_test.py
+++ b/tests/expr_and_series/n_unique_test.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-import pytest
-
 import narwhals.stable.v1 as nw
 from tests.utils import Constructor
 from tests.utils import ConstructorEager
@@ -13,9 +11,7 @@
 }
 
 
-def test_n_unique(constructor: Constructor, request: pytest.FixtureRequest) -> None:
-    if "duckdb" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
+def test_n_unique(constructor: Constructor) -> None:
     df = nw.from_native(constructor(data))
     result = df.select(nw.all().n_unique())
     expected = {"a": [3], "b": [4]}
diff --git a/tests/expr_and_series/unary_test.py b/tests/expr_and_series/unary_test.py
index 9ee38a230..f3e01d80f 100644
--- a/tests/expr_and_series/unary_test.py
+++ b/tests/expr_and_series/unary_test.py
@@ -126,11 +126,7 @@ def test_unary_two_elements_series(constructor_eager: ConstructorEager) -> None:
     assert_equal_data(result, expected)
 
 
-def test_unary_one_element(
-    constructor: Constructor, request: pytest.FixtureRequest
-) -> None:
-    if "duckdb" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
+def test_unary_one_element(constructor: Constructor) -> None:
     data = {"a": [1], "b": [2], "c": [None]}
     # Dask runs into a divide by zero RuntimeWarning for 1 element skew.
     context = (
diff --git a/tests/group_by_test.py b/tests/group_by_test.py
index 0dd6d8a10..c854da453 100644
--- a/tests/group_by_test.py
+++ b/tests/group_by_test.py
@@ -115,8 +115,6 @@ def test_group_by_depth_1_agg(
     expected: dict[str, list[int | float]],
     request: pytest.FixtureRequest,
 ) -> None:
-    if "duckdb" in str(constructor) and attr == "n_unique":
-        request.applymarker(pytest.mark.xfail)
     if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1):
         # Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations"
         request.applymarker(pytest.mark.xfail)
@@ -166,11 +164,7 @@ def test_group_by_median(constructor: Constructor) -> None:
     assert_equal_data(result, expected)
 
 
-def test_group_by_n_unique_w_missing(
-    constructor: Constructor, request: pytest.FixtureRequest
-) -> None:
-    if "duckdb" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
+def test_group_by_n_unique_w_missing(constructor: Constructor) -> None:
     data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]}
     result = (
         nw.from_native(constructor(data))
diff --git a/tpch/execute.py b/tpch/execute.py
index 1f3823ced..f2f3041df 100644
--- a/tpch/execute.py
+++ b/tpch/execute.py
@@ -40,7 +40,7 @@
     "dask": lambda x: x.compute(),
 }
 
-DUCKDB_XFAILS = ["q11", "q14", "q15", "q16", "q18", "q22"]
+DUCKDB_XFAILS = ["q11", "q14", "q15", "q18", "q22"]
 
 QUERY_DATA_PATH_MAP = {
     "q1": (LINEITEM_PATH,),