From 57b2373c501e5683893f66fed5ff94a85bc8fdba Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 17 Mar 2024 14:13:11 +0000 Subject: [PATCH] extra coverage --- narwhals/dataframe.py | 4 ++-- tpch/q1.py | 28 ++++++++++++++-------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 69b383b08..6e4b3185b 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -137,7 +137,7 @@ def join( ) -> Self: return self._from_dataframe( self._dataframe.join( - other._dataframe, + self._extract_native(other), how=how, left_on=left_on, right_on=right_on, @@ -148,7 +148,7 @@ def join( class DataFrame(BaseFrame): def __init__( self, - df: T, + df: Any, *, implementation: str | None = None, ) -> None: diff --git a/tpch/q1.py b/tpch/q1.py index 1365dcab8..b74a962d3 100644 --- a/tpch/q1.py +++ b/tpch/q1.py @@ -1,7 +1,7 @@ # ruff: noqa from typing import Any from datetime import datetime -from narwhals import translate_frame +import narwhals as nw import pandas as pd import polars @@ -10,33 +10,33 @@ def q1(df_raw: Any) -> Any: var_1 = datetime(1998, 9, 2) - df, pl = translate_frame(df_raw, is_lazy=True) + df = nw.LazyFrame(df_raw) result = ( - df.filter(pl.col("l_shipdate") <= var_1) + df.filter(nw.col("l_shipdate") <= var_1) .group_by(["l_returnflag", "l_linestatus"]) .agg( [ - pl.sum("l_quantity").alias("sum_qty"), - pl.sum("l_extendedprice").alias("sum_base_price"), - (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) + nw.sum("l_quantity").alias("sum_qty"), + nw.sum("l_extendedprice").alias("sum_base_price"), + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))) .sum() .alias("sum_disc_price"), ( - pl.col("l_extendedprice") - * (1.0 - pl.col("l_discount")) - * (1.0 + pl.col("l_tax")) + nw.col("l_extendedprice") + * (1.0 - nw.col("l_discount")) + * (1.0 + nw.col("l_tax")) ) .sum() .alias("sum_charge"), - pl.mean("l_quantity").alias("avg_qty"), - pl.mean("l_extendedprice").alias("avg_price"), - pl.mean("l_discount").alias("avg_disc"), - pl.len().alias("count_order"), + nw.mean("l_quantity").alias("avg_qty"), + nw.mean("l_extendedprice").alias("avg_price"), + nw.mean("l_discount").alias("avg_disc"), + nw.len().alias("count_order"), ], ) .sort(["l_returnflag", "l_linestatus"]) ) - return result.collect().to_native() + return nw.to_native(result.collect()) # df = pd.read_parquet("../tpch-data/lineitem.parquet")