Skip to content

Commit

Permalink
feat: implement n_unique for DuckDB (#1762)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jan 9, 2025
1 parent 40a83e3 commit 36dacf9
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 18 deletions.
22 changes: 22 additions & 0 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,28 @@ def sum(self) -> Self:
lambda _input: FunctionExpression("sum", _input), "sum", returns_scalar=True
)

def n_unique(self) -> Self:
    """Build a scalar expression counting the distinct values of this expression.

    The result is the sum of two aggregates:

    * ``array_unique(array_agg(x))`` -- the number of unique elements in the
      aggregated list (presumably excluding nulls; confirm against the
      DuckDB list-function docs).
    * ``max(CASE WHEN x IS NOT NULL THEN 0 ELSE 1 END)`` -- ``1`` if at
      least one row is null, ``0`` otherwise.

    Returns:
        A new expression produced by ``self._from_call`` with
        ``returns_scalar=True``.
    """
    from duckdb import CaseExpression
    from duckdb import ConstantExpression
    from duckdb import FunctionExpression

    def func(_input: duckdb.Expression) -> duckdb.Expression:
        # Approach taken from https://stackoverflow.com/a/79338887/4451315
        collected = FunctionExpression("array_agg", _input)
        distinct_count = FunctionExpression("array_unique", collected)

        # Per-row flag: 0 where the value is present, 1 where it is null.
        null_flag = CaseExpression(
            condition=_input.isnotnull(), value=ConstantExpression(0)
        ).otherwise(ConstantExpression(1))
        # max over the flag is 1 as soon as any row is null.
        null_bump = FunctionExpression("max", null_flag)

        return distinct_count + null_bump

    return self._from_call(func, "n_unique", returns_scalar=True)

def count(self) -> Self:
from duckdb import FunctionExpression

Expand Down
6 changes: 1 addition & 5 deletions tests/expr_and_series/n_unique_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
Expand All @@ -13,9 +11,7 @@
}


def test_n_unique(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_n_unique(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.all().n_unique())
expected = {"a": [3], "b": [4]}
Expand Down
6 changes: 1 addition & 5 deletions tests/expr_and_series/unary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,7 @@ def test_unary_two_elements_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, expected)


def test_unary_one_element(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_unary_one_element(constructor: Constructor) -> None:
data = {"a": [1], "b": [2], "c": [None]}
# Dask runs into a divide by zero RuntimeWarning for 1 element skew.
context = (
Expand Down
8 changes: 1 addition & 7 deletions tests/group_by_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ def test_group_by_depth_1_agg(
expected: dict[str, list[int | float]],
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor) and attr == "n_unique":
request.applymarker(pytest.mark.xfail)
if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1):
# Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations
request.applymarker(pytest.mark.xfail)
Expand Down Expand Up @@ -166,11 +164,7 @@ def test_group_by_median(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_group_by_n_unique_w_missing(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_group_by_n_unique_w_missing(constructor: Constructor) -> None:
data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]}
result = (
nw.from_native(constructor(data))
Expand Down
2 changes: 1 addition & 1 deletion tpch/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"dask": lambda x: x.compute(),
}

DUCKDB_XFAILS = ["q11", "q14", "q15", "q16", "q18", "q22"]
DUCKDB_XFAILS = ["q11", "q14", "q15", "q18", "q22"]

QUERY_DATA_PATH_MAP = {
"q1": (LINEITEM_PATH,),
Expand Down

0 comments on commit 36dacf9

Please sign in to comment.