def from_native(df: Any) -> BaseFrame:
    """Wrap a native dataframe object in the matching Narwhals frame.

    Dispatch order (first match wins):

    1. Polars ``DataFrame`` -> ``DataFrame``; Polars ``LazyFrame`` -> ``LazyFrame``.
    2. pandas, Modin or cuDF ``DataFrame`` -> ``DataFrame``.
    3. Any object exposing ``__narwhals_dataframe__`` -> ``DataFrame`` built
       from the value that method returns.
    4. Any object exposing ``__narwhals_lazyframe__`` -> ``LazyFrame`` built
       from the value that method returns.

    Args:
        df: The native object to wrap.

    Returns:
        A ``DataFrame`` or ``LazyFrame`` wrapping ``df``.

    Raises:
        TypeError: If ``df`` matches none of the cases above.
    """
    # Look polars up once; the getter may return None (presumably when the
    # library is not available), in which case both polars checks are skipped.
    polars = get_polars()
    if polars is not None:
        if isinstance(df, polars.DataFrame):
            return DataFrame(df)
        if isinstance(df, polars.LazyFrame):
            return LazyFrame(df)

    # pandas-like backends all map to the eager DataFrame wrapper. The getters
    # are called lazily and in order, so later backends are only looked up
    # when the earlier ones did not match.
    for get_module in (get_pandas, get_modin, get_cudf):
        module = get_module()
        if module is not None and isinstance(df, module.DataFrame):
            return DataFrame(df)

    # Fall back to objects that opt in through the Narwhals dunder protocol.
    if hasattr(df, "__narwhals_dataframe__"):  # pragma: no cover
        return DataFrame(df.__narwhals_dataframe__())
    if hasattr(df, "__narwhals_lazyframe__"):  # pragma: no cover
        return LazyFrame(df.__narwhals_lazyframe__())

    msg = f"Expected pandas-like dataframe, Polars dataframe, or Polars lazyframe, got: {type(df)}"
    raise TypeError(msg)
b/tpch/q1.py @@ -1,4 +1,5 @@ # ruff: noqa +import polars as pl from typing import Any from datetime import datetime import narwhals as nw @@ -10,7 +11,7 @@ def q1(df_raw: Any) -> Any: var_1 = datetime(1998, 9, 2) - df = nw.LazyFrame(df_raw) + df = nw.from_native(df_raw) result = ( df.filter(nw.col("l_shipdate") <= var_1) .with_columns( @@ -36,10 +37,14 @@ def q1(df_raw: Any) -> Any: ) .sort(["l_returnflag", "l_linestatus"]) ) - return nw.to_native(result.collect()) + return nw.to_native(result) df = pd.read_parquet( "../tpch-data/s1/lineitem.parquet", dtype_backend="pyarrow", engine="pyarrow" ) print(q1(df)) +df = pl.read_parquet("../tpch-data/s1/lineitem.parquet") +print(q1(df)) +df = pl.scan_parquet("../tpch-data/s1/lineitem.parquet") +print(q1(df).collect())