Skip to content

Commit

Permalink
test eager only for rank
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jan 7, 2025
1 parent 585d0d6 commit 6c72df7
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions tests/expr_and_series/rank_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import narwhals.stable.v1 as nw
from tests.utils import PANDAS_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

Expand Down Expand Up @@ -37,15 +36,12 @@
@pytest.mark.parametrize("data", [data_int, data_float])
def test_rank_expr(
request: pytest.FixtureRequest,
constructor: Constructor,
constructor_eager: ConstructorEager,
method: Literal["average", "min", "max", "dense", "ordinal"],
data: dict[str, list[float]],
) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

if (
"pandas_pyarrow" in str(constructor)
"pandas_pyarrow" in str(constructor_eager)
and PANDAS_VERSION < (2, 1)
and isinstance(data["a"][0], int)
):
Expand All @@ -56,12 +52,12 @@ def test_rank_expr(
ValueError,
match=r"`rank` with `method='average' is not supported for pyarrow backend.",
)
if "pyarrow_table" in str(constructor) and method == "average"
if "pyarrow_table" in str(constructor_eager) and method == "average"
else does_not_raise()
)

with context:
df = nw.from_native(constructor(data))
df = nw.from_native(constructor_eager(data))

result = df.select(nw.col("a").rank(method=method))
expected_data = {"a": expected[method]}
Expand Down Expand Up @@ -103,28 +99,28 @@ def test_rank_series(
@pytest.mark.parametrize("method", rank_methods)
def test_rank_expr_in_over_context(
request: pytest.FixtureRequest,
constructor: Constructor,
constructor_eager: ConstructorEager,
method: Literal["average", "min", "max", "dense", "ordinal"],
) -> None:
if any(x in str(constructor) for x in ("pyarrow_table", "dask")):
if any(x in str(constructor_eager) for x in ("pyarrow_table", "dask")):
# Pyarrow raises:
# > pyarrow.lib.ArrowKeyError: No function registered with name: hash_rank
# We can handle that to provide a better error message.
request.applymarker(pytest.mark.xfail)

if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1):
if "pandas_pyarrow" in str(constructor_eager) and PANDAS_VERSION < (2, 1):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data_float))
df = nw.from_native(constructor_eager(data_float))

result = df.select(nw.col("a").rank(method=method).over("b"))
expected_data = {"a": expected_over[method]}
assert_equal_data(result, expected_data)


def test_invalid_method_raise(constructor: Constructor) -> None:
def test_invalid_method_raise(constructor_eager: ConstructorEager) -> None:
method = "invalid_method_name"
df = nw.from_native(constructor(data_float))
df = nw.from_native(constructor_eager(data_float))

msg = (
"Ranking method must be one of {'average', 'min', 'max', 'dense', 'ordinal'}. "
Expand Down

0 comments on commit 6c72df7

Please sign in to comment.