Skip to content

Commit

Permalink
restore test
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 14, 2024
1 parent f5ec176 commit 8d4f77b
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 60 deletions.
8 changes: 8 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from narwhals.dataframe import DataFrame
from narwhals.expression import col
from narwhals.expression import len
from narwhals.expression import max
from narwhals.expression import mean
from narwhals.expression import min
from narwhals.expression import sum
from narwhals.translate import get_namespace
from narwhals.translate import to_native
from narwhals.translate import translate_any
Expand All @@ -27,5 +31,9 @@
"to_native",
"col",
"len",
"min",
"max",
"mean",
"sum",
"DataFrame",
]
16 changes: 16 additions & 0 deletions narwhals/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,19 @@ def col(col_name: str):

def len():
    """Return a ``NarwhalsExpr`` whose callable evaluates ``plx.len()``.

    The wrapped callable takes a single argument ``plx`` and returns the
    result of calling its ``len()`` method.
    NOTE(review): ``plx`` is presumably a backend namespace object — confirm
    against ``NarwhalsExpr``.
    """

    def _call_len(plx):
        return plx.len()

    return NarwhalsExpr(_call_len)


def sum(col_name):
    """Return a ``NarwhalsExpr`` whose callable evaluates ``plx.sum(col_name)``.

    ``col_name`` is captured by the wrapped callable and forwarded unchanged
    to ``plx.sum`` when the callable is eventually invoked with ``plx``.
    NOTE(review): ``plx`` is presumably a backend namespace object — confirm
    against ``NarwhalsExpr``.
    """

    def _call_sum(plx):
        return plx.sum(col_name)

    return NarwhalsExpr(_call_sum)


def mean(col_name):
    """Return a ``NarwhalsExpr`` whose callable evaluates ``plx.mean(col_name)``.

    ``col_name`` is captured by the wrapped callable and forwarded unchanged
    to ``plx.mean`` when the callable is eventually invoked with ``plx``.
    NOTE(review): ``plx`` is presumably a backend namespace object — confirm
    against ``NarwhalsExpr``.
    """

    def _call_mean(plx):
        return plx.mean(col_name)

    return NarwhalsExpr(_call_mean)


def min(col_name):
    """Return a ``NarwhalsExpr`` whose callable evaluates ``plx.min(col_name)``.

    ``col_name`` is captured by the wrapped callable and forwarded unchanged
    to ``plx.min`` when the callable is eventually invoked with ``plx``.
    NOTE(review): ``plx`` is presumably a backend namespace object — confirm
    against ``NarwhalsExpr``.
    """

    def _call_min(plx):
        return plx.min(col_name)

    return NarwhalsExpr(_call_min)


def max(col_name):
    """Return a ``NarwhalsExpr`` whose callable evaluates ``plx.max(col_name)``.

    ``col_name`` is captured by the wrapped callable and forwarded unchanged
    to ``plx.max`` when the callable is eventually invoked with ``plx``.
    NOTE(review): ``plx`` is presumably a backend namespace object — confirm
    against ``NarwhalsExpr``.
    """

    def _call_max(plx):
        return plx.max(col_name)

    return NarwhalsExpr(_call_max)
121 changes: 61 additions & 60 deletions tests/tpch_q1_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import os
from datetime import datetime
from typing import Any
from unittest import mock

import polars
import pytest
Expand Down Expand Up @@ -71,63 +73,62 @@ def test_q1(df_raw: Any) -> None:
compare_dicts(result, expected)


# @pytest.mark.parametrize(
# "df_raw",
# [
# (polars.read_parquet("tests/data/lineitem.parquet").to_pandas()),
# ],
# )
# @mock.patch.dict(os.environ, {"NARWHALS_FORCE_GENERIC": "1"})
# def test_q1_w_pandas_agg_generic_path(df_raw: Any) -> None:
# var_1 = datetime(1998, 9, 2)
# df = translate_frame(df_raw, is_lazy=True)
# pl = get_namespace(df)
# query_result = (
# df.filter(pl.col("l_shipdate") <= var_1)
# .group_by(["l_returnflag", "l_linestatus"])
# .agg(
# [
# pl.sum("l_quantity").alias("sum_qty"),
# pl.sum("l_extendedprice").alias("sum_base_price"),
# (pl.col("l_extendedprice") * (1 - pl.col("l_discount")))
# .sum()
# .alias("sum_disc_price"),
# (
# pl.col("l_extendedprice")
# * (1.0 - pl.col("l_discount"))
# * (1.0 + pl.col("l_tax"))
# )
# .sum()
# .alias("sum_charge"),
# pl.mean("l_quantity").alias("avg_qty"),
# pl.mean("l_extendedprice").alias("avg_price"),
# pl.mean("l_discount").alias("avg_disc"),
# pl.len().alias("count_order"),
# ],
# )
# .sort(["l_returnflag", "l_linestatus"])
# )
# result = query_result.collect().to_dict(as_series=False)
# expected = {
# "l_returnflag": ["A", "N", "N", "R"],
# "l_linestatus": ["F", "F", "O", "F"],
# "sum_qty": [2109.0, 29.0, 3682.0, 1876.0],
# "sum_base_price": [3114026.44, 39824.83, 5517101.99, 2947892.16],
# "sum_disc_price": [2954950.8082, 39028.3334, 5205468.4852, 2816542.4816999994],
# "sum_charge": [
# 3092840.4194289995,
# 39808.900068,
# 5406966.873419,
# 2935797.8313019997,
# ],
# "avg_qty": [27.75, 29.0, 25.047619047619047, 26.422535211267604],
# "avg_price": [
# 40974.032105263155,
# 39824.83,
# 37531.30605442177,
# 41519.607887323946,
# ],
# "avg_disc": [0.05039473684210526, 0.02, 0.05537414965986395, 0.04507042253521127],
# "count_order": [76, 1, 147, 71],
# }
# compare_dicts(result, expected)
@pytest.mark.parametrize(
    "df_raw",
    [polars.read_parquet("tests/data/lineitem.parquet").to_pandas()],
)
@mock.patch.dict(os.environ, {"NARWHALS_FORCE_GENERIC": "1"})
def test_q1_w_pandas_agg_generic_path(df_raw: Any) -> None:
    """Run the Q1 aggregation on a pandas frame with NARWHALS_FORCE_GENERIC=1.

    The lineitem parquet file is read with polars and converted to pandas,
    wrapped in a lazy ``nw.DataFrame``, filtered by ship date, grouped by
    return flag and line status, aggregated, sorted, collected, and compared
    against hard-coded expected values.
    NOTE(review): the environment variable presumably forces a generic
    (non-specialised) aggregation code path — confirm in the library.
    """
    ship_date_cutoff = datetime(1998, 9, 2)
    key_columns = ("l_returnflag", "l_linestatus")

    frame = nw.DataFrame(df_raw, is_lazy=True)
    # A fresh list is handed to each call, as in the original.
    grouped = frame.filter(nw.col("l_shipdate") <= ship_date_cutoff).group_by(
        list(key_columns)
    )

    # Aggregations are built one by one in the order they are passed to agg().
    sum_qty = nw.sum("l_quantity").alias("sum_qty")
    sum_base_price = nw.sum("l_extendedprice").alias("sum_base_price")
    discounted_price = nw.col("l_extendedprice") * (1 - nw.col("l_discount"))
    sum_disc_price = discounted_price.sum().alias("sum_disc_price")
    charged_price = (
        nw.col("l_extendedprice")
        * (1.0 - nw.col("l_discount"))
        * (1.0 + nw.col("l_tax"))
    )
    sum_charge = charged_price.sum().alias("sum_charge")
    avg_qty = nw.mean("l_quantity").alias("avg_qty")
    avg_price = nw.mean("l_extendedprice").alias("avg_price")
    avg_disc = nw.mean("l_discount").alias("avg_disc")
    count_order = nw.len().alias("count_order")

    aggregated = grouped.agg(
        [
            sum_qty,
            sum_base_price,
            sum_disc_price,
            sum_charge,
            avg_qty,
            avg_price,
            avg_disc,
            count_order,
        ],
    )
    query_result = aggregated.sort(list(key_columns))
    result = query_result.collect().to_dict(as_series=False)

    expected = {
        "l_returnflag": ["A", "N", "N", "R"],
        "l_linestatus": ["F", "F", "O", "F"],
        "sum_qty": [2109.0, 29.0, 3682.0, 1876.0],
        "sum_base_price": [3114026.44, 39824.83, 5517101.99, 2947892.16],
        "sum_disc_price": [
            2954950.8082,
            39028.3334,
            5205468.4852,
            2816542.4816999994,
        ],
        "sum_charge": [
            3092840.4194289995,
            39808.900068,
            5406966.873419,
            2935797.8313019997,
        ],
        "avg_qty": [27.75, 29.0, 25.047619047619047, 26.422535211267604],
        "avg_price": [
            40974.032105263155,
            39824.83,
            37531.30605442177,
            41519.607887323946,
        ],
        "avg_disc": [
            0.05039473684210526,
            0.02,
            0.05537414965986395,
            0.04507042253521127,
        ],
        "count_order": [76, 1, 147, 71],
    }
    compare_dicts(result, expected)

0 comments on commit 8d4f77b

Please sign in to comment.