Skip to content

Commit

Permalink
Merge branch 'main' of github.com:MarcoGorelli/narwhals into increase-coverage
Browse files
  • Loading branch information
raisa committed Mar 17, 2024
2 parents 3ac19ab + 57b2373 commit 567aaf3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def join(
  ) -> Self:
  return self._from_dataframe(
  self._dataframe.join(
- other._dataframe,
+ self._extract_native(other),
  how=how,
  left_on=left_on,
  right_on=right_on,
Expand All @@ -148,7 +148,7 @@ def join(
  class DataFrame(BaseFrame):
  def __init__(
  self,
- df: T,
+ df: Any,
  *,
  implementation: str | None = None,
  ) -> None:
Expand Down
28 changes: 14 additions & 14 deletions tpch/q1.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
  # ruff: noqa
  from typing import Any
  from datetime import datetime
- from narwhals import translate_frame
+ import narwhals as nw
  import pandas as pd
  import polars

Expand All @@ -10,33 +10,33 @@

  def q1(df_raw: Any) -> Any:
  var_1 = datetime(1998, 9, 2)
- df, pl = translate_frame(df_raw, is_lazy=True)
+ df = nw.LazyFrame(df_raw)
  result = (
- df.filter(pl.col("l_shipdate") <= var_1)
+ df.filter(nw.col("l_shipdate") <= var_1)
  .group_by(["l_returnflag", "l_linestatus"])
  .agg(
  [
- pl.sum("l_quantity").alias("sum_qty"),
- pl.sum("l_extendedprice").alias("sum_base_price"),
- (pl.col("l_extendedprice") * (1 - pl.col("l_discount")))
+ nw.sum("l_quantity").alias("sum_qty"),
+ nw.sum("l_extendedprice").alias("sum_base_price"),
+ (nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
  .sum()
  .alias("sum_disc_price"),
  (
- pl.col("l_extendedprice")
- * (1.0 - pl.col("l_discount"))
- * (1.0 + pl.col("l_tax"))
+ nw.col("l_extendedprice")
+ * (1.0 - nw.col("l_discount"))
+ * (1.0 + nw.col("l_tax"))
  )
  .sum()
  .alias("sum_charge"),
- pl.mean("l_quantity").alias("avg_qty"),
- pl.mean("l_extendedprice").alias("avg_price"),
- pl.mean("l_discount").alias("avg_disc"),
- pl.len().alias("count_order"),
+ nw.mean("l_quantity").alias("avg_qty"),
+ nw.mean("l_extendedprice").alias("avg_price"),
+ nw.mean("l_discount").alias("avg_disc"),
+ nw.len().alias("count_order"),
  ],
  )
  .sort(["l_returnflag", "l_linestatus"])
  )
- return result.collect().to_native()
+ return nw.to_native(result.collect())


# df = pd.read_parquet("../tpch-data/lineitem.parquet")
Expand Down

0 comments on commit 567aaf3

Please sign in to comment.