diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 2406cbecf8..87ec52ef50 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -20,7 +20,6 @@ from narwhals.group_by import GroupBy from narwhals.series import Series from narwhals.typing import IntoExpr - from narwhals.typing import T class BaseFrame: @@ -208,7 +207,7 @@ def to_dict(self, *, as_series: bool = True) -> dict[str, Any]: class LazyFrame(BaseFrame): def __init__( self, - df: T, + df: Any, *, implementation: str | None = None, ) -> None: diff --git a/narwhals/pandas_like/dataframe.py b/narwhals/pandas_like/dataframe.py index 41babe4af7..32dc9ae753 100644 --- a/narwhals/pandas_like/dataframe.py +++ b/narwhals/pandas_like/dataframe.py @@ -103,7 +103,7 @@ def filter( # Safety: all_horizontal's expression only returns a single column. mask = expr._call(self)[0] _mask = validate_dataframe_comparand(mask) - return self._from_dataframe(self._dataframe.loc[_mask]) + return self._from_dataframe(self._dataframe[_mask]) def with_columns( self, diff --git a/narwhals/pandas_like/utils.py b/narwhals/pandas_like/utils.py index 7f0351003e..9c3d25ee43 100644 --- a/narwhals/pandas_like/utils.py +++ b/narwhals/pandas_like/utils.py @@ -60,12 +60,6 @@ def validate_dataframe_comparand(other: Any) -> Any: from narwhals.pandas_like.dataframe import PandasDataFrame from narwhals.pandas_like.series import PandasSeries - if isinstance(other, list) and len(other) > 1: - # e.g. `plx.all() + plx.all()` - msg = "Multi-output expressions are not supported in this context" - raise ValueError(msg) - if isinstance(other, list): - other = other[0] if isinstance(other, PandasDataFrame): return NotImplemented if isinstance(other, PandasSeries): @@ -73,6 +67,12 @@ def validate_dataframe_comparand(other: Any) -> Any: # broadcast return item(other._series) return other._series + if isinstance(other, list) and len(other) > 1: + # e.g. `plx.all() + plx.all()` + msg = "Multi-output expressions are not supported in this context" + raise ValueError(msg) + if isinstance(other, list): + other = other[0] return other diff --git a/tpch/q1.py b/tpch/q1.py index 965069548a..e634cec47b 100644 --- a/tpch/q1.py +++ b/tpch/q1.py @@ -39,9 +39,7 @@ def q1(df_raw: Any) -> Any: return nw.to_native(result.collect()) -df = pd.read_parquet("../tpch-data/s1/lineitem.parquet", dtype_backend="pyarrow") -breakpoint() -# df["l_shipdate"] = pd.to_datetime(df["l_shipdate"]) -print(q1(df)) -df = polars.scan_parquet("../tpch-data/s1/lineitem.parquet") +df = pd.read_parquet( + "../tpch-data/s1/lineitem.parquet", dtype_backend="pyarrow", engine="pyarrow" +) print(q1(df))