diff --git a/.github/workflows/check_tpch_queries.yml b/.github/workflows/check_tpch_queries.yml new file mode 100644 index 000000000..46dd5df20 --- /dev/null +++ b/.github/workflows/check_tpch_queries.yml @@ -0,0 +1,30 @@ +name: Tests for TPCH Queries + +on: + pull_request: + push: + branches: [main] + +jobs: + validate-queries: + strategy: + matrix: + python-version: ["3.12"] + os: [ubuntu-latest] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install uv + run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: install-reqs + run: uv pip install --upgrade -r requirements-dev.txt --system + - name: local-install + run: uv pip install -e . --system + - name: generate-data + run: cd tpch && python generate_data.py + - name: tpch-tests + run: cd tpch && pytest tests \ No newline at end of file diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml index ae9c79009..7e1a5586e 100644 --- a/.github/workflows/extremes.yml +++ b/.github/workflows/extremes.yml @@ -104,7 +104,7 @@ jobs: - name: uninstall pandas run: uv pip uninstall pandas --system - name: install-pandas-nightly - run: uv pip install --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple pandas --system + run: uv pip install --prerelease=allow --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple pandas --system - name: uninstall numpy run: uv pip uninstall numpy --system - name: install numpy nightly diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 0458a4729..265442e9f 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -25,7 +25,7 @@ jobs: if: runner.os == 'Windows' run: powershell -c "irm https://astral.sh/uv/install.ps1 | iex" - name: install-reqs - run: uv pip install --upgrade tox virtualenv setuptools -r requirements-dev.txt --system + run: uv pip install --upgrade tox virtualenv setuptools -r requirements-dev.txt ibis-framework[duckdb] --system - name: show-deps run: uv pip freeze - name: Run pytest @@ -78,6 +78,11 @@ jobs: run: uv pip install --upgrade modin[dask] --system - name: show-deps run: uv pip freeze + - name: install ibis + run: uv pip install ibis-framework[duckdb] --system + # Ibis puts upper bounds on dependencies, and requires Python3.10+, + # which messes with other dependencies on lower Python versions + if: matrix.python-version == '3.12' - name: Run pytest run: pytest tests --cov=narwhals --cov=tests --cov-fail-under=100 --runslow - name: Run doctests diff --git a/.gitignore b/.gitignore index d6edf57e6..3825a68a6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,36 @@ -.venv +# Byte-compiled / optimized / DLL files +__pycache__/ *.pyc -todo.md + +# Distribution / packaging +dist/ + +# Unit test / coverage reports +.nox/ .coverage -site/ .coverage.* -.nox -*.lock +.cache +coverage.xml +.hypothesis/ +.pytest_cache/ +# Documentation +site/ +todo.md docs/api-completeness/*.md -!docs/api-completeness/index.md \ No newline at end of file +!docs/api-completeness/index.md + +# Lock files +*.lock + +# Environments +.venv + +# TPC-H data +tpch/data/* + +# VSCode +.vscode/ + +# MacOS +.DS_Store diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4914e7e16..f3a68e7a0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. 
- rev: 'v0.5.7' + rev: 'v0.6.3' hooks: # Run the formatter. - id: ruff-format @@ -9,11 +9,11 @@ repos: - id: ruff args: [--fix] - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.11.1' + rev: 'v1.11.2' hooks: - id: mypy additional_dependencies: ['polars==1.4.1', 'pytest==8.3.2'] - exclude: utils|tpch + files: ^(narwhals|tests)/ - repo: https://github.com/codespell-project/codespell rev: 'v2.3.0' hooks: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d36d21a55..aeed2538f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -47,22 +47,41 @@ git clone git@github.com:YOUR-USERNAME/narwhals.git ### 4. Setting up your environment -Here's how you can set up your local development environment to contribute: - -1. Make sure you have Python3.8+ installed (for example, Python 3.11) -2. Create a new virtual environment with `python3.11 -m venv .venv` (or whichever version of Python3.9+ you prefer) -3. Activate it: `. .venv/bin/activate` -4. Install Narwhals: `pip install -e .` -5. Install test requirements: `pip install -r requirements-dev.txt` -6. Install docs requirements: `pip install -r docs/requirements-docs.txt` +Here's how you can set up your local development environment to contribute. + +#### Option 1: Use UV (recommended) + +1. Make sure you have Python3.8+ installed (for example, Python 3.11), create a virtual environment, + and activate it. If you're new to this, here's one way that we recommend: + 1. Install uv: https://github.com/astral-sh/uv?tab=readme-ov-file#getting-started + 2. Install some version of Python greater than Python3.8. For example, to install + Python3.11: + ``` + uv python install 3.11 + ``` + 3. Create a virtual environment: + ``` + uv venv -p 3.11 --seed + ``` + 4. Activate it. On Linux, this is `. .venv/bin/activate`, on Windows `.\.venv\Scripts\activate`. +2. Install Narwhals: `uv pip install -e .` +3. Install test requirements: `uv pip install -r requirements-dev.txt` +4. Install docs requirements: `uv pip install -r docs/requirements-docs.txt` You should also install pre-commit: ``` -pip install pre-commit +uv pip install pre-commit pre-commit install ``` This will automatically format and lint your code before each commit, and it will block the commit if any issues are found. +#### Option 2: use python3-venv + +1. Make sure you have Python 3.8+ installed. If you don't, you can check [install Python](https://realpython.com/installing-python/) + to learn how. Then, [create and activate](https://realpython.com/python-virtual-environments-a-primer/) + a virtual environment. +2. Then, follow steps 2-4 from above but using `pip install` instead of `uv pip install`. + ### 5. Working on your issue Create a new git branch from the `main` branch in your local repository. diff --git a/README.md b/README.md index d26107e67..74630fd03 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,31 @@ provided some funding / development time: If you contribute to Narwhals on your organization's time, please let us know. We'd be happy to add your employer to this list! +## Appears on + +Narwhals has been featured in several talks, podcasts, and blog posts: + +- [Talk Python to me Podcast](https://youtu.be/FSH7BZ0tuE0) + Ahoy, Narwhals are bridging the data science APIs + +- [Super Data Science: ML & AI Podcast](https://www.youtube.com/watch?v=TeG4U8R0U8U) + Narwhals: For Pandas-to-Polars DataFrame Compatibility + +- [Sample Space Podcast | probabl](https://youtu.be/8hYdq4sWbbQ?si=WG0QP1CZ6gkFf18b) + How Narwhals has many end users ... that never use it directly. 
- Marco Gorelli + +- [Pycon Lithuania](https://www.youtube.com/watch?v=-mdx7Cn6_6E) + Marco Gorelli - DataFrame interoperatiblity - what's been achieved, and what comes next? + +- [Pycon Italy](https://www.youtube.com/watch?v=3IqUli9XsmQ) + How you can write a dataframe-agnostic library - Marco Gorelli + +- [Polars Blog Post](https://pola.rs/posts/lightweight_plotting/) + Polars has a new lightweight plotting backend + +- [Quansight Labs blog post (w/ Scikit-Lego)](https://labs.quansight.org/blog/scikit-lego-narwhals) + How Narwhals and scikit-lego came together to achieve dataframe-agnosticism + ## Why "Narwhals"? [Coz they are so awesome](https://youtu.be/ykwqXuMPsoc?si=A-i8LdR38teYsos4). diff --git a/docs/api-reference/dataframe.md b/docs/api-reference/dataframe.md index c144b4af0..f78b4e3da 100644 --- a/docs/api-reference/dataframe.md +++ b/docs/api-reference/dataframe.md @@ -22,6 +22,7 @@ - item - iter_rows - join + - join_asof - lazy - null_count - pipe diff --git a/docs/api-reference/dependencies.md b/docs/api-reference/dependencies.md index 6c1a93d91..959e8ee0c 100644 --- a/docs/api-reference/dependencies.md +++ b/docs/api-reference/dependencies.md @@ -5,6 +5,7 @@ options: members: - get_cudf + - get_ibis - get_modin - get_pandas - get_polars @@ -12,6 +13,7 @@ - is_cudf_dataframe - is_cudf_series - is_dask_dataframe + - is_ibis_table - is_modin_dataframe - is_modin_series - is_numpy_array diff --git a/docs/api-reference/expr.md b/docs/api-reference/expr.md index cc1290a85..7188b2c36 100644 --- a/docs/api-reference/expr.md +++ b/docs/api-reference/expr.md @@ -30,6 +30,7 @@ - max - mean - min + - mode - null_count - n_unique - over diff --git a/docs/api-reference/lazyframe.md b/docs/api-reference/lazyframe.md index 9ca6a9745..5d472bab6 100644 --- a/docs/api-reference/lazyframe.md +++ b/docs/api-reference/lazyframe.md @@ -15,6 +15,7 @@ - group_by - head - join + - join_asof - lazy - pipe - rename diff --git a/docs/api-reference/series.md b/docs/api-reference/series.md index f9cc2e6bb..c016b566d 100644 --- a/docs/api-reference/series.md +++ b/docs/api-reference/series.md @@ -12,6 +12,7 @@ - any - arg_true - cast + - clip - count - cum_sum - diff @@ -22,7 +23,6 @@ - gather_every - head - is_between - - clip - is_duplicated - is_empty - is_first_distinct @@ -36,13 +36,15 @@ - max - mean - min + - mode - name - - null_count - n_unique + - null_count - pipe - quantile - round - sample + - scatter - shape - shift - sort diff --git a/docs/installation.md b/docs/installation.md index 04625c686..5a49dba8f 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -11,6 +11,6 @@ Then, if you start the Python REPL and see the following: ```python >>> import narwhals >>> narwhals.__version__ -'1.5.5' +'1.7.0' ``` then installation worked correctly! diff --git a/docs/why.md b/docs/why.md index adf8f39b4..4ec605d16 100644 --- a/docs/why.md +++ b/docs/why.md @@ -27,7 +27,7 @@ pl_df_right = pl.DataFrame({"a": [1, 2, 3], "c": [4, 5, 6]}) pl_left_merge = pl_df_left.join(pl_df_right, left_on="b", right_on="c", how="left") print(pd_left_merge.columns) -print(pl_df_right.columns) +print(pl_left_merge.columns) ``` There are several such subtle difference between the libraries. Writing dataframe-agnostic code is hard! 
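> Note (not part of the diff): the `docs/why.md` hunk above illustrates how pandas and Polars name columns differently after a left join, and the source changes further down add `on=` and `suffix=` parameters to `join` (plus a new `join_asof`). Below is a minimal, hedged sketch of how those new keywords are meant to be used through the public API; the function name `left_join_agnostic` and the sample data are made up for illustration, while `nw.from_native` / `nw.to_native` are existing narwhals helpers.

```python
import narwhals as nw
import pandas as pd
import polars as pl


def left_join_agnostic(df_native, other_native):
    # Wrap whichever native frame we were given (pandas, Polars, ...).
    df, other = nw.from_native(df_native), nw.from_native(other_native)
    # `on="a"` matches the key in both frames; `suffix` renames the clashing
    # right-hand "b" column (the default is "_right"), mirroring Polars semantics.
    result = df.join(other, on="a", how="left", suffix="_other")
    # Unwrap back to the caller's native type.
    return nw.to_native(result)


left = {"a": [1, 2, 3], "b": [4, 5, 6]}
right = {"a": [1, 2, 3], "b": [7, 8, 9]}

print(left_join_agnostic(pd.DataFrame(left), pd.DataFrame(right)))
print(left_join_agnostic(pl.DataFrame(left), pl.DataFrame(right)))
```

Either way, the caller gets back a frame of the same type it passed in, with the right-hand `b` column renamed to `b_other`.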
diff --git a/narwhals/__init__.py b/narwhals/__init__.py index 0977e716a..d76ad2262 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -53,7 +53,7 @@ from narwhals.utils import maybe_get_index from narwhals.utils import maybe_set_index -__version__ = "1.5.5" +__version__ = "1.7.0" __all__ = [ "dependencies", diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 20e507166..f409ef735 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: import numpy as np + import pyarrow as pa from typing_extensions import Self from narwhals._arrow.group_by import ArrowGroupBy @@ -33,7 +34,7 @@ class ArrowDataFrame: # --- not in the spec --- def __init__( - self, native_dataframe: Any, *, backend_version: tuple[int, ...] + self, native_dataframe: pa.Table, *, backend_version: tuple[int, ...] ) -> None: self._native_frame = native_dataframe self._implementation = Implementation.PYARROW @@ -120,7 +121,12 @@ def __getitem__(self, item: str) -> ArrowSeries: ... def __getitem__(self, item: slice) -> ArrowDataFrame: ... def __getitem__( - self, item: str | slice | Sequence[int] | tuple[Sequence[int], str | int] + self, + item: str + | slice + | Sequence[int] + | Sequence[str] + | tuple[Sequence[int], str | int], ) -> ArrowSeries | ArrowDataFrame: if isinstance(item, str): from narwhals._arrow.series import ArrowSeries @@ -135,9 +141,12 @@ def __getitem__( and len(item) == 2 and isinstance(item[1], (list, tuple)) ): - return self._from_native_frame( - self._native_frame.take(item[0]).select(item[1]) - ) + if item[0] == slice(None): + selected_rows = self._native_frame + else: + selected_rows = self._native_frame.take(item[0]) + + return self._from_native_frame(selected_rows.select(item[1])) elif isinstance(item, tuple) and len(item) == 2: if isinstance(item[1], slice): @@ -187,6 +196,8 @@ def __getitem__( ) elif isinstance(item, Sequence) or (is_numpy_array(item) and item.ndim == 1): + if isinstance(item, Sequence) and all(isinstance(x, str) for x in item): + return self._from_native_frame(self._native_frame.select(item)) return self._from_native_frame(self._native_frame.take(item)) else: # pragma: no cover @@ -273,12 +284,8 @@ def join( how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner", left_on: str | list[str] | None, right_on: str | list[str] | None, + suffix: str, ) -> Self: - if isinstance(left_on, str): - left_on = [left_on] - if isinstance(right_on, str): - right_on = [right_on] - how_to_join_map = { "anti": "left anti", "semi": "left semi", @@ -299,7 +306,7 @@ def join( keys=key_token, right_keys=key_token, join_type="inner", - right_suffix="_right", + right_suffix=suffix, ) .drop([key_token]), ) @@ -310,10 +317,25 @@ def join( keys=left_on, right_keys=right_on, join_type=how_to_join_map[how], - right_suffix="_right", + right_suffix=suffix, ), ) + def join_asof( + self, + other: Self, + *, + left_on: str | None = None, + right_on: str | None = None, + on: str | None = None, + by_left: str | list[str] | None = None, + by_right: str | list[str] | None = None, + by: str | list[str] | None = None, + strategy: Literal["backward", "forward", "nearest"] = "backward", + ) -> Self: + msg = "join_asof is not yet supported on PyArrow tables" + raise NotImplementedError(msg) + def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001 to_drop = parse_columns_to_drop( compliant_frame=self, columns=columns, strict=strict diff --git a/narwhals/_arrow/expr.py 
b/narwhals/_arrow/expr.py index ca9293b8b..24e4fe5c5 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -15,6 +15,7 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.namespace import ArrowNamespace from narwhals._arrow.series import ArrowSeries + from narwhals._arrow.typing import IntoArrowExpr from narwhals.dtypes import DType @@ -157,7 +158,7 @@ def __invert__(self) -> Self: def len(self) -> Self: return reuse_series_implementation(self, "len", returns_scalar=True) - def filter(self, *predicates: Any) -> Self: + def filter(self, *predicates: IntoArrowExpr) -> Self: plx = self.__narwhals_namespace__() expr = plx.all_horizontal(*predicates) return reuse_series_implementation(self, "filter", other=expr) @@ -228,7 +229,7 @@ def null_count(self) -> Self: def is_null(self) -> Self: return reuse_series_implementation(self, "is_null") - def is_between(self, lower_bound: Any, upper_bound: Any, closed: str) -> Any: + def is_between(self, lower_bound: Any, upper_bound: Any, closed: str) -> Self: return reuse_series_implementation( self, "is_between", lower_bound, upper_bound, closed ) @@ -308,7 +309,9 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: ) raise ValueError(msg) tmp = df.group_by(*keys).agg(self) - tmp = df.select(*keys).join(tmp, how="left", left_on=keys, right_on=keys) + tmp = df.select(*keys).join( + tmp, how="left", left_on=keys, right_on=keys, suffix="_right" + ) return [tmp[name] for name in self._output_names] return self.__class__( @@ -320,6 +323,9 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: backend_version=self._backend_version, ) + def mode(self: Self) -> Self: + return reuse_series_implementation(self, "mode") + @property def dt(self: Self) -> ArrowExprDateTimeNamespace: return ArrowExprDateTimeNamespace(self) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 27c7ff368..6c7b20485 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -15,6 +15,26 @@ from narwhals._arrow.expr import ArrowExpr from narwhals._arrow.typing import IntoArrowExpr +POLARS_TO_ARROW_AGGREGATIONS = { + "len": "count", + "n_unique": "count_distinct", + "std": "stddev", + "var": "variance", # currently unused, we don't have `var` yet +} + + +def get_function_name_option(function_name: str) -> Any | None: + """Map specific pyarrow compute function to respective option to match polars behaviour.""" + import pyarrow.compute as pc # ignore-banned-import + + function_name_to_options = { + "count": pc.CountOptions(mode="all"), + "count_distinct": pc.CountOptions(mode="all"), + "stddev": pc.VarianceOptions(ddof=1), + "variance": pc.VarianceOptions(ddof=1), + } + return function_name_to_options.get(function_name) + class ArrowGroupBy: def __init__(self, df: ArrowDataFrame, keys: list[str]) -> None: @@ -112,17 +132,14 @@ def agg_arrow( raise AssertionError(msg) function_name = remove_prefix(expr._function_name, "col->") + function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name) + + option = get_function_name_option(function_name) for root_name, output_name in zip(expr._root_names, expr._output_names): - if function_name != "len": - simple_aggregations[output_name] = ( - (root_name, function_name), - f"{root_name}_{function_name}", - ) - else: - simple_aggregations[output_name] = ( - (root_name, "count", pc.CountOptions(mode="all")), - f"{root_name}_count", - ) + simple_aggregations[output_name] = ( + (root_name, function_name, option), + f"{root_name}_{function_name}", + ) aggs: 
list[Any] = [] name_mapping = {} diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 1e9e4a08c..73390fdd3 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -18,15 +18,21 @@ from narwhals.utils import generate_unique_token if TYPE_CHECKING: + import pyarrow as pa from typing_extensions import Self from narwhals._arrow.dataframe import ArrowDataFrame + from narwhals._arrow.namespace import ArrowNamespace from narwhals.dtypes import DType class ArrowSeries: def __init__( - self, native_series: Any, *, name: str, backend_version: tuple[int, ...] + self, + native_series: pa.ChunkedArray, + *, + name: str, + backend_version: tuple[int, ...], ) -> None: self._name = name self._native_series = native_series @@ -60,6 +66,11 @@ def _from_iterable( backend_version=backend_version, ) + def __narwhals_namespace__(self) -> ArrowNamespace: + from narwhals._arrow.namespace import ArrowNamespace + + return ArrowNamespace(backend_version=self._backend_version) + def __len__(self) -> int: return len(self._native_series) @@ -310,8 +321,27 @@ def __getitem__(self, idx: slice | Sequence[int]) -> Self: ... def __getitem__(self, idx: int | slice | Sequence[int]) -> Any | Self: if isinstance(idx, int): return self._native_series[idx] + if isinstance(idx, Sequence): + return self._from_native_series(self._native_series.take(idx)) return self._from_native_series(self._native_series[idx]) + def scatter(self, indices: int | Sequence[int], values: Any) -> Self: + import numpy as np # ignore-banned-import + import pyarrow as pa # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import + + ca = self._native_series + mask = np.zeros(len(ca), dtype=bool) + mask[indices] = True + if isinstance(values, self.__class__): + values = validate_column_comparand(values) + if isinstance(values, pa.ChunkedArray): + values = values.combine_chunks() + if not isinstance(values, pa.Array): + values = pa.array(values) + result = pc.replace_with_mask(ca, mask, values.take(indices)) + return self._from_native_series(result) + def to_list(self) -> Any: return self._native_series.to_pylist() @@ -366,7 +396,9 @@ def all(self) -> bool: return pc.all(self._native_series) # type: ignore[no-any-return] - def is_between(self, lower_bound: Any, upper_bound: Any, closed: str = "both") -> Any: + def is_between( + self, lower_bound: Any, upper_bound: Any, closed: str = "both" + ) -> Self: import pyarrow.compute as pc # ignore-banned-import() ser = self._native_series @@ -657,9 +689,16 @@ def clip( return self._from_native_series(arr) - def to_arrow(self: Self) -> Any: + def to_arrow(self: Self) -> pa.Array: return self._native_series.combine_chunks() + def mode(self: Self) -> ArrowSeries: + plx = self.__narwhals_namespace__() + col_token = generate_unique_token(n_bytes=8, columns=[self.name]) + return self.value_counts(name=col_token, normalize=False).filter( + plx.col(col_token) == plx.col(col_token).max() + )[self.name] + @property def shape(self) -> tuple[int]: return (len(self._native_series),) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 6f7517aeb..b8294839c 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -51,7 +51,7 @@ def translate_dtype(dtype: Any) -> dtypes.DType: return dtypes.Duration() if pa.types.is_dictionary(dtype): return dtypes.Categorical() - raise AssertionError + return dtypes.Unknown() # pragma: no cover def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any: diff --git 
a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 9774d6c8e..180a897bd 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -18,9 +18,11 @@ from narwhals.utils import parse_version if TYPE_CHECKING: + import dask.dataframe as dd from typing_extensions import Self from narwhals._dask.expr import DaskExpr + from narwhals._dask.group_by import DaskLazyGroupBy from narwhals._dask.namespace import DaskNamespace from narwhals._dask.typing import IntoDaskExpr from narwhals.dtypes import DType @@ -28,7 +30,7 @@ class DaskLazyFrame: def __init__( - self, native_dataframe: Any, *, backend_version: tuple[int, ...] + self, native_dataframe: dd.DataFrame, *, backend_version: tuple[int, ...] ) -> None: self._native_frame = native_dataframe self._backend_version = backend_version @@ -77,14 +79,17 @@ def filter( and isinstance(predicates[0], list) and all(isinstance(x, bool) for x in predicates[0]) ): - mask = predicates[0] - else: - from narwhals._dask.namespace import DaskNamespace + msg = ( + "`LazyFrame.filter` is not supported for Dask backend with boolean masks." + ) + raise NotImplementedError(msg) - plx = DaskNamespace(backend_version=self._backend_version) - expr = plx.all_horizontal(*predicates) - # Safety: all_horizontal's expression only returns a single column. - mask = expr._call(self)[0] + from narwhals._dask.namespace import DaskNamespace + + plx = DaskNamespace(backend_version=self._backend_version) + expr = plx.all_horizontal(*predicates) + # Safety: all_horizontal's expression only returns a single column. + mask = expr._call(self)[0] return self._from_native_frame(self._native_frame.loc[mask]) def lazy(self) -> Self: @@ -206,12 +211,8 @@ def join( how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner", left_on: str | list[str] | None, right_on: str | list[str] | None, + suffix: str, ) -> Self: - if isinstance(left_on, str): - left_on = [left_on] - if isinstance(right_on, str): - right_on = [right_on] - if how == "cross": key_token = generate_unique_token( n_bytes=8, columns=[*self.columns, *other.columns] @@ -224,7 +225,7 @@ def join( how="inner", left_on=key_token, right_on=key_token, - suffixes=("", "_right"), + suffixes=("", suffix), ) .drop(columns=key_token), ) @@ -276,7 +277,7 @@ def join( how="left", left_on=left_on, right_on=right_on, - suffixes=("", "_right"), + suffixes=("", suffix), ) extra = [] for left_key, right_key in zip(left_on, right_on): # type: ignore[arg-type] @@ -292,17 +293,52 @@ def join( left_on=left_on, right_on=right_on, how=how, + suffixes=("", suffix), + ), + ) + + def join_asof( + self, + other: Self, + *, + left_on: str | None = None, + right_on: str | None = None, + on: str | None = None, + by_left: str | list[str] | None = None, + by_right: str | list[str] | None = None, + by: str | list[str] | None = None, + strategy: Literal["backward", "forward", "nearest"] = "backward", + ) -> Self: + plx = self.__native_namespace__() + return self._from_native_frame( + plx.merge_asof( + self._native_frame, + other._native_frame, + left_on=left_on, + right_on=right_on, + on=on, + left_by=by_left, + right_by=by_right, + by=by, + direction=strategy, suffixes=("", "_right"), ), ) - def group_by(self, *by: str) -> Any: + def group_by(self, *by: str) -> DaskLazyGroupBy: from narwhals._dask.group_by import DaskLazyGroupBy return DaskLazyGroupBy(self, list(by)) def tail(self: Self, n: int) -> Self: - return self._from_native_frame(self._native_frame.tail(n=n, compute=False)) + native_frame = self._native_frame + 
n_partitions = native_frame.npartitions + + if n_partitions == 1: + return self._from_native_frame(self._native_frame.tail(n=n, compute=False)) + else: + msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions." + raise NotImplementedError(msg) def gather_every(self: Self, n: int, offset: int) -> Self: row_index_token = generate_unique_token(n_bytes=8, columns=self.columns) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 62aaa85e6..f08af590c 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -14,6 +14,7 @@ from narwhals.utils import generate_unique_token if TYPE_CHECKING: + import dask_expr from typing_extensions import Self from narwhals._dask.dataframe import DaskLazyFrame @@ -24,8 +25,7 @@ class DaskExpr: def __init__( self, - # callable from DaskLazyFrame to list of (native) Dask Series - call: Callable[[DaskLazyFrame], Any], + call: Callable[[DaskLazyFrame], list[dask_expr.Series]], *, depth: int, function_name: str, @@ -58,7 +58,7 @@ def from_column_names( *column_names: str, backend_version: tuple[int, ...], ) -> Self: - def func(df: DaskLazyFrame) -> list[Any]: + def func(df: DaskLazyFrame) -> list[dask_expr.Series]: return [df._native_frame.loc[:, column_name] for column_name in column_names] return cls( @@ -73,14 +73,14 @@ def func(df: DaskLazyFrame) -> list[Any]: def _from_call( self, - # callable from DaskLazyFrame to list of (native) Dask Series - call: Any, + # First argument to `call` should be `dask_expr.Series` + call: Callable[..., dask_expr.Series], expr_name: str, *args: Any, returns_scalar: bool, **kwargs: Any, ) -> Self: - def func(df: DaskLazyFrame) -> list[Any]: + def func(df: DaskLazyFrame) -> list[dask_expr.Series]: results = [] inputs = self._call(df) _args = [maybe_evaluate(df, x) for x in args] @@ -131,7 +131,7 @@ def func(df: DaskLazyFrame) -> list[Any]: ) def alias(self, name: str) -> Self: - def func(df: DaskLazyFrame) -> list[Any]: + def func(df: DaskLazyFrame) -> list[dask_expr.Series]: inputs = self._call(df) return [_input.rename(name) for _input in inputs] @@ -524,8 +524,15 @@ def quantile( interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], ) -> Self: if interpolation == "linear": + + def func(_input: dask_expr.Series, _quantile: float) -> dask_expr.Series: + if _input.npartitions > 1: + msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions." 
+ raise NotImplementedError(msg) + return _input.quantile(q=_quantile, method="dask") + return self._from_call( - lambda _input, quantile: _input.quantile(q=quantile, method="dask"), + func, "quantile", quantile, returns_scalar=True, @@ -535,7 +542,7 @@ def quantile( raise NotImplementedError(msg) def is_first_distinct(self: Self) -> Self: - def func(_input: Any) -> Any: + def func(_input: dask_expr.Series) -> dask_expr.Series: _name = _input.name col_token = generate_unique_token(n_bytes=8, columns=[_name]) _input = add_row_index(_input.to_frame(), col_token) @@ -552,7 +559,7 @@ def func(_input: Any) -> Any: ) def is_last_distinct(self: Self) -> Self: - def func(_input: Any) -> Any: + def func(_input: dask_expr.Series) -> dask_expr.Series: _name = _input.name col_token = generate_unique_token(n_bytes=8, columns=[_name]) _input = add_row_index(_input.to_frame(), col_token) @@ -567,7 +574,7 @@ def func(_input: Any) -> Any: ) def is_duplicated(self: Self) -> Self: - def func(_input: Any) -> Any: + def func(_input: dask_expr.Series) -> dask_expr.Series: _name = _input.name return ( _input.to_frame().groupby(_name).transform("size", meta=(_name, int)) > 1 @@ -580,7 +587,7 @@ def func(_input: Any) -> Any: ) def is_unique(self: Self) -> Self: - def func(_input: Any) -> Any: + def func(_input: dask_expr.Series) -> dask_expr.Series: _name = _input.name return ( _input.to_frame().groupby(_name).transform("size", meta=(_name, int)) == 1 @@ -626,13 +633,18 @@ def func(df: DaskLazyFrame) -> list[Any]: "`nw.col('a', 'b')`\n" ) raise ValueError(msg) + + if df._native_frame.npartitions > 1: + msg = "`Expr.over` is not supported for Dask backend with multiple partitions." + raise NotImplementedError(msg) + tmp = df.group_by(*keys).agg(self) - tmp = ( + tmp_native = ( df.select(*keys) - .join(tmp, how="left", left_on=keys, right_on=keys) + .join(tmp, how="left", left_on=keys, right_on=keys, suffix="_right") ._native_frame ) - return [tmp[name] for name in self._output_names] + return [tmp_native[name] for name in self._output_names] return self.__class__( func, @@ -644,6 +656,10 @@ def func(df: DaskLazyFrame) -> list[Any]: backend_version=self._backend_version, ) + def mode(self: Self) -> Self: + msg = "`Expr.mode` is not supported for the Dask backend." 
+ raise NotImplementedError(msg) + @property def str(self: Self) -> DaskExprStringNamespace: return DaskExprStringNamespace(self) diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 8538c62d2..d5fbaaf94 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -10,12 +10,33 @@ from narwhals.utils import remove_prefix if TYPE_CHECKING: + import dask.dataframe as dd + import pandas as pd + from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr from narwhals._dask.typing import IntoDaskExpr -POLARS_TO_PANDAS_AGGREGATIONS = { + +def n_unique() -> dd.Aggregation: + import dask.dataframe as dd # ignore-banned-import + + def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> int: + return s.nunique(dropna=False) # type: ignore[no-any-return] + + def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int: + return s0.sum() # type: ignore[no-any-return] + + return dd.Aggregation( + name="nunique", + chunk=chunk, + agg=agg, + ) + + +POLARS_TO_DASK_AGGREGATIONS = { "len": "size", + "n_unique": n_unique, } @@ -51,6 +72,7 @@ def agg( output_names.extend(expr._output_names) return agg_dask( + self._df, self._grouped, exprs, self._keys, @@ -67,6 +89,7 @@ def _from_native_frame(self, df: DaskLazyFrame) -> DaskLazyFrame: def agg_dask( + df: DaskLazyFrame, grouped: Any, exprs: list[DaskExpr], keys: list[str], @@ -78,6 +101,10 @@ def agg_dask( - https://github.com/rapidsai/cudf/issues/15118 - https://github.com/rapidsai/cudf/issues/15084 """ + if not exprs: + # No aggregation provided + return df.select(*keys).unique(subset=keys) + all_simple_aggs = True for expr in exprs: if not is_simple_aggregation(expr): @@ -85,7 +112,7 @@ def agg_dask( break if all_simple_aggs: - simple_aggregations: dict[str, tuple[str, str]] = {} + simple_aggregations: dict[str, tuple[str, str | dd.Aggregation]] = {} for expr in exprs: if expr._depth == 0: # e.g. 
agg(nw.len()) # noqa: ERA001 @@ -93,7 +120,7 @@ def agg_dask( msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" raise AssertionError(msg) - function_name = POLARS_TO_PANDAS_AGGREGATIONS.get( + function_name = POLARS_TO_DASK_AGGREGATIONS.get( expr._function_name, expr._function_name ) for output_name in expr._output_names: @@ -108,9 +135,11 @@ def agg_dask( raise AssertionError(msg) function_name = remove_prefix(expr._function_name, "col->") - function_name = POLARS_TO_PANDAS_AGGREGATIONS.get( - function_name, function_name - ) + function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name) + + # deal with n_unique case in a "lazy" mode to not depend on dask globally + function_name = function_name() if callable(function_name) else function_name + for root_name, output_name in zip(expr._root_names, expr._output_names): simple_aggregations[output_name] = (root_name, function_name) try: diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index e6019b509..1668ee323 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -4,20 +4,23 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Iterable from typing import NoReturn from typing import cast from narwhals import dtypes +from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr from narwhals._dask.selectors import DaskSelectorNamespace +from narwhals._dask.utils import reverse_translate_dtype from narwhals._dask.utils import validate_comparand from narwhals._expression_parsing import parse_into_exprs if TYPE_CHECKING: import dask_expr - from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.typing import IntoDaskExpr + from narwhals.dtypes import DType class DaskNamespace: @@ -69,10 +72,17 @@ def col(self, *column_names: str) -> DaskExpr: ) def lit(self, value: Any, dtype: dtypes.DType | None) -> DaskExpr: - # TODO @FBruzzesi: cast to dtype once `narwhals_to_native_dtype` is implemented. 
- # It should be enough to add `.astype(narwhals_to_native_dtype(dtype))` + def convert_if_dtype( + series: dask_expr.Series, dtype: DType | type[DType] + ) -> dask_expr.Series: + return series.astype(reverse_translate_dtype(dtype)) if dtype else series + return DaskExpr( - lambda df: [df._native_frame.assign(lit=value).loc[:, "lit"]], + lambda df: [ + df._native_frame.assign(lit=value) + .loc[:, "lit"] + .pipe(convert_if_dtype, dtype) + ], depth=0, function_name="lit", root_names=None, @@ -142,6 +152,47 @@ def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: [expr.fill_null(0) for expr in parse_into_exprs(*exprs, namespace=self)], ) + def concat( + self, + items: Iterable[DaskLazyFrame], + *, + how: str = "vertical", + ) -> DaskLazyFrame: + import dask.dataframe as dd # ignore-banned-import + + if len(list(items)) == 0: + msg = "No items to concatenate" # pragma: no cover + raise AssertionError(msg) + native_frames = [i._native_frame for i in items] + if how == "vertical": + if not all( + tuple(i.columns) == tuple(native_frames[0].columns) for i in native_frames + ): # pragma: no cover + msg = "unable to vstack with non-matching columns" + raise AssertionError(msg) + return DaskLazyFrame( + dd.concat(native_frames, axis=0, join="inner"), + backend_version=self._backend_version, + ) + if how == "horizontal": + all_column_names: list[str] = [ + column for frame in native_frames for column in frame.columns + ] + if len(all_column_names) != len(set(all_column_names)): # pragma: no cover + duplicates = [ + i for i in all_column_names if all_column_names.count(i) > 1 + ] + msg = ( + f"Columns with name(s): {', '.join(duplicates)} " + "have more than one occurrence" + ) + raise AssertionError(msg) + return DaskLazyFrame( + dd.concat(native_frames, axis=1, join="outer"), + backend_version=self._backend_version, + ) + raise NotImplementedError + def mean_horizontal(self, *exprs: IntoDaskExpr) -> IntoDaskExpr: dask_exprs = parse_into_exprs(*exprs, namespace=self) total = reduce(lambda x, y: x + y, (e.fill_null(0.0) for e in dask_exprs)) diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 073b3abd8..d3525f71f 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -8,6 +8,7 @@ from narwhals._dask.expr import DaskExpr if TYPE_CHECKING: + import dask_expr from typing_extensions import Self from narwhals._dask.dataframe import DaskLazyFrame @@ -118,12 +119,10 @@ def call(df: DaskLazyFrame) -> list[Any]: def __or__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any: if isinstance(other, DaskSelector): - def call(df: DaskLazyFrame) -> list[Any]: + def call(df: DaskLazyFrame) -> list[dask_expr.Series]: lhs = self._call(df) rhs = other._call(df) - return [ # type: ignore[no-any-return] - x for x in lhs if x.name not in [x.name for x in rhs] - ] + rhs + return [x for x in lhs if x.name not in [x.name for x in rhs]] + rhs return DaskSelector( call, diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index e7fb64d02..1f5cda4ba 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -9,6 +9,7 @@ from narwhals.utils import parse_version if TYPE_CHECKING: + import dask.dataframe as dd import dask_expr from narwhals._dask.dataframe import DaskLazyFrame @@ -34,7 +35,7 @@ def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any: def parse_exprs_and_named_exprs( df: DaskLazyFrame, *exprs: Any, **named_exprs: Any -) -> dict[str, Any]: +) -> dict[str, dask_expr.Series]: results = {} for expr in exprs: if hasattr(expr, 
"__narwhals_expr__"): @@ -62,7 +63,7 @@ def parse_exprs_and_named_exprs( return results -def add_row_index(frame: Any, name: str) -> Any: +def add_row_index(frame: dd.DataFrame, name: str) -> dd.DataFrame: frame = frame.assign(**{name: 1}) return frame.assign(**{name: frame[name].cumsum(method="blelloch") - 1}) diff --git a/tpch/__init__.py b/narwhals/_duckdb/__init__.py similarity index 100% rename from tpch/__init__.py rename to narwhals/_duckdb/__init__.py diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py new file mode 100644 index 000000000..8de244658 --- /dev/null +++ b/narwhals/_duckdb/dataframe.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Any + +from narwhals import dtypes + +if TYPE_CHECKING: + from narwhals._duckdb.series import DuckDBInterchangeSeries + + +def map_duckdb_dtype_to_narwhals_dtype( + duckdb_dtype: Any, +) -> dtypes.DType: + if duckdb_dtype == "BIGINT": + return dtypes.Int64() + if duckdb_dtype == "INTEGER": + return dtypes.Int32() + if duckdb_dtype == "SMALLINT": + return dtypes.Int16() + if duckdb_dtype == "TINYINT": + return dtypes.Int8() + if duckdb_dtype == "UBIGINT": + return dtypes.UInt64() + if duckdb_dtype == "UINTEGER": + return dtypes.UInt32() + if duckdb_dtype == "USMALLINT": + return dtypes.UInt16() + if duckdb_dtype == "UTINYINT": + return dtypes.UInt8() + if duckdb_dtype == "DOUBLE": + return dtypes.Float64() + if duckdb_dtype == "FLOAT": + return dtypes.Float32() + if duckdb_dtype == "VARCHAR": + return dtypes.String() + if duckdb_dtype == "DATE": + return dtypes.Date() + if duckdb_dtype == "TIMESTAMP": + return dtypes.Datetime() + if duckdb_dtype == "BOOLEAN": + return dtypes.Boolean() + if duckdb_dtype == "INTERVAL": + return dtypes.Duration() + return dtypes.Unknown() + + +class DuckDBInterchangeFrame: + def __init__(self, df: Any) -> None: + self._native_frame = df + + def __narwhals_dataframe__(self) -> Any: + return self + + def __getitem__(self, item: str) -> DuckDBInterchangeSeries: + from narwhals._duckdb.series import DuckDBInterchangeSeries + + return DuckDBInterchangeSeries(self._native_frame.select(item)) + + def __getattr__(self, attr: str) -> Any: + if attr == "schema": + return { + column_name: map_duckdb_dtype_to_narwhals_dtype(duckdb_dtype) + for column_name, duckdb_dtype in zip( + self._native_frame.columns, self._native_frame.types + ) + } + + msg = ( # pragma: no cover + f"Attribute {attr} is not supported for metadata-only dataframes.\n\n" + "If you would like to see this kind of object better supported in " + "Narwhals, please open a feature request " + "at https://github.com/narwhals-dev/narwhals/issues." 
+ ) + raise NotImplementedError(msg) # pragma: no cover diff --git a/narwhals/_duckdb/series.py b/narwhals/_duckdb/series.py new file mode 100644 index 000000000..f19a6f76f --- /dev/null +++ b/narwhals/_duckdb/series.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import Any + +from narwhals._duckdb.dataframe import map_duckdb_dtype_to_narwhals_dtype + + +class DuckDBInterchangeSeries: + def __init__(self, df: Any) -> None: + self._native_series = df + + def __narwhals_series__(self) -> Any: + return self + + def __getattr__(self, attr: str) -> Any: + if attr == "dtype": + return map_duckdb_dtype_to_narwhals_dtype(self._native_series.types[0]) + msg = ( # pragma: no cover + f"Attribute {attr} is not supported for metadata-only dataframes.\n\n" + "If you would like to see this kind of object better supported in " + "Narwhals, please open a feature request " + "at https://github.com/narwhals-dev/narwhals/issues." + ) + raise NotImplementedError(msg) # pragma: no cover diff --git a/narwhals/_ibis/__init__.py b/narwhals/_ibis/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py new file mode 100644 index 000000000..e2baa4ec4 --- /dev/null +++ b/narwhals/_ibis/dataframe.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Any + +from narwhals import dtypes + +if TYPE_CHECKING: + from narwhals._ibis.series import IbisInterchangeSeries + + +def map_ibis_dtype_to_narwhals_dtype( + ibis_dtype: Any, +) -> dtypes.DType: + if ibis_dtype.is_int64(): + return dtypes.Int64() + if ibis_dtype.is_int32(): + return dtypes.Int32() + if ibis_dtype.is_int16(): + return dtypes.Int16() + if ibis_dtype.is_int8(): + return dtypes.Int8() + if ibis_dtype.is_uint64(): + return dtypes.UInt64() + if ibis_dtype.is_uint32(): + return dtypes.UInt32() + if ibis_dtype.is_uint16(): + return dtypes.UInt16() + if ibis_dtype.is_uint8(): + return dtypes.UInt8() + if ibis_dtype.is_boolean(): + return dtypes.Boolean() + if ibis_dtype.is_float64(): + return dtypes.Float64() + if ibis_dtype.is_float32(): + return dtypes.Float32() + if ibis_dtype.is_string(): + return dtypes.String() + if ibis_dtype.is_date(): + return dtypes.Date() + if ibis_dtype.is_timestamp(): + return dtypes.Datetime() + return dtypes.Unknown() # pragma: no cover + + +class IbisInterchangeFrame: + def __init__(self, df: Any) -> None: + self._native_frame = df + + def __narwhals_dataframe__(self) -> Any: + return self + + def __getitem__(self, item: str) -> IbisInterchangeSeries: + from narwhals._ibis.series import IbisInterchangeSeries + + return IbisInterchangeSeries(self._native_frame[item]) + + def __getattr__(self, attr: str) -> Any: + if attr == "schema": + return { + column_name: map_ibis_dtype_to_narwhals_dtype(ibis_dtype) + for column_name, ibis_dtype in self._native_frame.schema().items() + } + msg = ( + f"Attribute {attr} is not supported for metadata-only dataframes.\n\n" + "If you would like to see this kind of object better supported in " + "Narwhals, please open a feature request " + "at https://github.com/narwhals-dev/narwhals/issues." 
+ ) + raise NotImplementedError(msg) diff --git a/narwhals/_ibis/series.py b/narwhals/_ibis/series.py new file mode 100644 index 000000000..73e3b6d47 --- /dev/null +++ b/narwhals/_ibis/series.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import Any + +from narwhals._ibis.dataframe import map_ibis_dtype_to_narwhals_dtype + + +class IbisInterchangeSeries: + def __init__(self, df: Any) -> None: + self._native_series = df + + def __narwhals_series__(self) -> Any: + return self + + def __getattr__(self, attr: str) -> Any: + if attr == "dtype": + return map_ibis_dtype_to_narwhals_dtype(self._native_series.type()) + msg = ( + f"Attribute {attr} is not supported for metadata-only dataframes.\n\n" + "If you would like to see this kind of object better supported in " + "Narwhals, please open a feature request " + "at https://github.com/narwhals-dev/narwhals/issues." + ) + raise NotImplementedError(msg) diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py index bf1b17243..2e9775258 100644 --- a/narwhals/_interchange/dataframe.py +++ b/narwhals/_interchange/dataframe.py @@ -70,6 +70,7 @@ def map_interchange_dtype_to_narwhals_dtype( class InterchangeFrame: def __init__(self, df: Any) -> None: self._native_frame = df + self._interchange_frame = df.__dataframe__() def __narwhals_dataframe__(self) -> Any: return self @@ -77,15 +78,15 @@ def __narwhals_dataframe__(self) -> Any: def __getitem__(self, item: str) -> InterchangeSeries: from narwhals._interchange.series import InterchangeSeries - return InterchangeSeries(self._native_frame.get_column_by_name(item)) + return InterchangeSeries(self._interchange_frame.get_column_by_name(item)) @property def schema(self) -> dict[str, dtypes.DType]: return { column_name: map_interchange_dtype_to_narwhals_dtype( - self._native_frame.get_column_by_name(column_name).dtype + self._interchange_frame.get_column_by_name(column_name).dtype ) - for column_name in self._native_frame.column_names() + for column_name in self._interchange_frame.column_names() } def __getattr__(self, attr: str) -> NoReturn: diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 193955cbd..71a659998 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -111,13 +111,22 @@ def __getitem__(self, item: tuple[Sequence[int], str | int]) -> PandasLikeSeries def __getitem__(self, item: Sequence[int]) -> PandasLikeDataFrame: ... @overload - def __getitem__(self, item: str) -> PandasLikeSeries: ... + def __getitem__(self, item: str) -> PandasLikeSeries: ... # type: ignore[overload-overlap] + + @overload + def __getitem__(self, item: Sequence[str]) -> PandasLikeDataFrame: ... @overload def __getitem__(self, item: slice) -> PandasLikeDataFrame: ... 
def __getitem__( - self, item: str | slice | Sequence[int] | tuple[Sequence[int], str | int] + self, + item: str + | int + | slice + | Sequence[int] + | Sequence[str] + | tuple[Sequence[int], str | int], ) -> PandasLikeSeries | PandasLikeDataFrame: if isinstance(item, str): from narwhals._pandas_like.series import PandasLikeSeries @@ -174,7 +183,7 @@ def __getitem__( from narwhals._pandas_like.series import PandasLikeSeries if isinstance(item[1], str): - item = (item[0], self._native_frame.columns.get_loc(item[1])) + item = (item[0], self._native_frame.columns.get_loc(item[1])) # type: ignore[assignment] native_series = self._native_frame.iloc[item] elif isinstance(item[1], int): native_series = self._native_frame.iloc[item] @@ -191,6 +200,8 @@ def __getitem__( elif isinstance(item, (slice, Sequence)) or ( is_numpy_array(item) and item.ndim == 1 ): + if isinstance(item, Sequence) and all(isinstance(x, str) for x in item): + return self._from_native_frame(self._native_frame.loc[:, item]) return self._from_native_frame(self._native_frame.iloc[item]) else: # pragma: no cover @@ -403,12 +414,12 @@ def join( how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner", left_on: str | list[str] | None, right_on: str | list[str] | None, + suffix: str, ) -> Self: if isinstance(left_on, str): left_on = [left_on] if isinstance(right_on, str): right_on = [right_on] - if how == "cross": if ( self._implementation is Implementation.MODIN @@ -428,7 +439,7 @@ def join( how="inner", left_on=key_token, right_on=key_token, - suffixes=("", "_right"), + suffixes=("", suffix), ) .drop(columns=key_token), ) @@ -437,7 +448,7 @@ def join( self._native_frame.merge( other._native_frame, how="cross", - suffixes=("", "_right"), + suffixes=("", suffix), ), ) @@ -489,14 +500,14 @@ def join( how="left", left_on=left_on, right_on=right_on, - suffixes=("", "_right"), + suffixes=("", suffix), ) extra = [] for left_key, right_key in zip(left_on, right_on): # type: ignore[arg-type] if right_key != left_key and right_key not in self.columns: extra.append(right_key) elif right_key != left_key: - extra.append(f"{right_key}_right") + extra.append(f"{right_key}{suffix}") return self._from_native_frame(result_native.drop(columns=extra)) return self._from_native_frame( @@ -505,6 +516,34 @@ def join( left_on=left_on, right_on=right_on, how=how, + suffixes=("", suffix), + ), + ) + + def join_asof( + self, + other: Self, + *, + left_on: str | None = None, + right_on: str | None = None, + on: str | None = None, + by_left: str | list[str] | None = None, + by_right: str | list[str] | None = None, + by: str | list[str] | None = None, + strategy: Literal["backward", "forward", "nearest"] = "backward", + ) -> Self: + plx = self.__native_namespace__() + return self._from_native_frame( + plx.merge_asof( + self._native_frame, + other._native_frame, + left_on=left_on, + right_on=right_on, + on=on, + left_by=by_left, + right_by=by_right, + by=by, + direction=strategy, suffixes=("", "_right"), ), ) @@ -562,8 +601,8 @@ def to_numpy(self, dtype: Any = None, copy: bool | None = None) -> Any: from narwhals._pandas_like.series import PANDAS_TO_NUMPY_DTYPE_MISSING if copy is None: - # pandas default differs from Polars - copy = False + # pandas default differs from Polars, but cuDF default is True + copy = self._implementation is Implementation.CUDF if dtype is not None: return self._native_frame.to_numpy(dtype=dtype, copy=copy) @@ -649,8 +688,7 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: def to_arrow(self: Self) -> 
Any: if self._implementation is Implementation.CUDF: # pragma: no cover - msg = "`to_arrow` is not implemented for CuDF backend." - raise NotImplementedError(msg) + return self._native_frame.to_arrow(preserve_index=False) import pyarrow as pa # ignore-banned-import() diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 44154453d..74a2ee31d 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -287,7 +287,9 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: ) raise ValueError(msg) tmp = df.group_by(*keys).agg(self) - tmp = df.select(*keys).join(tmp, how="left", left_on=keys, right_on=keys) + tmp = df.select(*keys).join( + tmp, how="left", left_on=keys, right_on=keys, suffix="_right" + ) return [tmp[name] for name in self._output_names] return self.__class__( @@ -336,6 +338,9 @@ def len(self: Self) -> Self: def gather_every(self: Self, n: int, offset: int = 0) -> Self: return reuse_series_implementation(self, "gather_every", n=n, offset=offset) + def mode(self: Self) -> Self: + return reuse_series_implementation(self, "mode") + @property def str(self: Self) -> PandasLikeExprStringNamespace: return PandasLikeExprStringNamespace(self) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 11abc85c8..892291d57 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -21,6 +21,7 @@ POLARS_TO_PANDAS_AGGREGATIONS = { "len": "size", + "n_unique": "nunique", } @@ -79,6 +80,7 @@ def agg( dataframe_is_empty=self._df._native_frame.empty, implementation=implementation, backend_version=self._df._backend_version, + native_namespace=self._df.__native_namespace__(), ) def _from_native_frame(self, df: PandasLikeDataFrame) -> PandasLikeDataFrame: @@ -103,7 +105,7 @@ def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]: yield from ((key, self._from_native_frame(sub_df)) for (key, sub_df) in iterator) -def agg_pandas( +def agg_pandas( # noqa: PLR0915 grouped: Any, exprs: list[PandasLikeExpr], keys: list[str], @@ -113,6 +115,7 @@ def agg_pandas( implementation: Any, backend_version: tuple[int, ...], dataframe_is_empty: bool, + native_namespace: Any, ) -> PandasLikeDataFrame: """ This should be the fastpath, but cuDF is too far behind to use it. @@ -120,13 +123,18 @@ def agg_pandas( - https://github.com/rapidsai/cudf/issues/15118 - https://github.com/rapidsai/cudf/issues/15084 """ - all_simple_aggs = True + all_aggs_are_simple = True for expr in exprs: if not is_simple_aggregation(expr): - all_simple_aggs = False + all_aggs_are_simple = False break - if all_simple_aggs: + # dict of {output_name: root_name} that we count n_unique on + # We need to do this separately from the rest so that we + # can pass the `dropna` kwargs. 
+ nunique_aggs: dict[str, str] = {} + + if all_aggs_are_simple: simple_aggregations: dict[str, tuple[str, str]] = {} for expr in exprs: if expr._depth == 0: @@ -154,21 +162,54 @@ def agg_pandas( function_name, function_name ) for root_name, output_name in zip(expr._root_names, expr._output_names): - simple_aggregations[output_name] = (root_name, function_name) + if function_name == "nunique": + nunique_aggs[output_name] = root_name + else: + simple_aggregations[output_name] = (root_name, function_name) - aggs = collections.defaultdict(list) + simple_aggs = collections.defaultdict(list) name_mapping = {} for output_name, named_agg in simple_aggregations.items(): - aggs[named_agg[0]].append(named_agg[1]) + simple_aggs[named_agg[0]].append(named_agg[1]) name_mapping[f"{named_agg[0]}_{named_agg[1]}"] = output_name - try: - result_simple = grouped.agg(aggs) - except AttributeError as exc: - msg = "Failed to aggregated - does your aggregation function return a scalar?" - raise RuntimeError(msg) from exc - result_simple.columns = [f"{a}_{b}" for a, b in result_simple.columns] - result_simple = result_simple.rename(columns=name_mapping).reset_index() - return from_dataframe(result_simple.loc[:, output_names]) + if simple_aggs: + try: + result_simple_aggs = grouped.agg(simple_aggs) + except AttributeError as exc: + msg = "Failed to aggregated - does your aggregation function return a scalar?" + raise RuntimeError(msg) from exc + result_simple_aggs.columns = [ + f"{a}_{b}" for a, b in result_simple_aggs.columns + ] + result_simple_aggs = result_simple_aggs.rename( + columns=name_mapping + ).reset_index() + if nunique_aggs: + result_nunique_aggs = grouped[list(nunique_aggs.values())].nunique( + dropna=False + ) + result_nunique_aggs.columns = list(nunique_aggs.keys()) + result_nunique_aggs = result_nunique_aggs.reset_index() + if simple_aggs and nunique_aggs: + if ( + set(result_simple_aggs.columns) + .difference(keys) + .intersection(result_nunique_aggs.columns) + ): + msg = ( + "Got two aggregations with the same output name. Please make sure " + "that aggregations have unique output names." + ) + raise ValueError(msg) + result_aggs = result_simple_aggs.merge(result_nunique_aggs, on=keys) + elif nunique_aggs and not simple_aggs: + result_aggs = result_nunique_aggs + elif simple_aggs and not nunique_aggs: + result_aggs = result_simple_aggs + else: + # No aggregation provided + result_aggs = native_namespace.DataFrame(grouped.groups.keys(), columns=keys) + return from_dataframe(result_aggs.loc[:, output_names]) if dataframe_is_empty: # Don't even attempt this, it's way too inconsistent across pandas versions. 
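> Note (not part of the diff): the group-by change above routes `n_unique` through a separate `nunique(dropna=False)` call because pandas' dict-style `.agg` gives no way to pass the `dropna` keyword, then merges that result back with the ordinary aggregations on the group keys. A standalone sketch of the underlying pandas mechanics is shown below; the frame, the group key `g`, and the output column names are hypothetical.

```python
import pandas as pd

df = pd.DataFrame({"g": ["x", "x", "y"], "a": [1.0, 1.0, None], "b": [1, 2, 3]})
grouped = df.groupby(["g"], dropna=False)

# Ordinary aggregations go through `.agg`; the resulting MultiIndex columns
# are flattened to "<column>_<aggregation>" names.
simple = grouped.agg({"b": ["sum"]})
simple.columns = ["b_sum"]

# `n_unique` is computed separately so that `dropna=False` can be passed,
# counting missing values as a distinct value (as Polars' n_unique does).
nunique = grouped[["a"]].nunique(dropna=False)
nunique.columns = ["a_n_unique"]

# Finally, the two result frames are merged back together on the group keys.
result = simple.reset_index().merge(nunique.reset_index(), on=["g"])
print(result)
```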
diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index e94c95a8c..8288be263 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -10,6 +10,7 @@ from narwhals._pandas_like.utils import int_dtype_mapper from narwhals._pandas_like.utils import narwhals_to_native_dtype from narwhals._pandas_like.utils import native_series_from_iterable +from narwhals._pandas_like.utils import set_axis from narwhals._pandas_like.utils import to_datetime from narwhals._pandas_like.utils import translate_dtype from narwhals._pandas_like.utils import validate_column_comparand @@ -167,6 +168,22 @@ def shape(self) -> tuple[int]: def dtype(self) -> DType: return translate_dtype(self._native_series) + def scatter(self, indices: int | Sequence[int], values: Any) -> Self: + if isinstance(values, self.__class__): + # .copy() is necessary in some pre-2.2 versions of pandas to avoid + # `values` also getting modified (!) + values = validate_column_comparand(self._native_series.index, values).copy() + values = set_axis( + values, + self._native_series.index[indices], + implementation=self._implementation, + backend_version=self._backend_version, + ) + s = self._native_series + s.iloc[indices] = values + s.name = self.name + return self._from_native_series(s) + def cast( self, dtype: Any, @@ -473,7 +490,7 @@ def __array__(self, dtype: Any = None, copy: bool | None = None) -> Any: def to_numpy(self, dtype: Any = None, copy: bool | None = None) -> Any: # the default is meant to be None, but pandas doesn't allow it? # https://numpy.org/doc/stable/reference/generated/numpy.ndarray.__array__.html - copy = copy or False + copy = copy or self._implementation is Implementation.CUDF has_missing = self._native_series.isna().any() if ( @@ -635,13 +652,18 @@ def clip( def to_arrow(self: Self) -> Any: if self._implementation is Implementation.CUDF: # pragma: no cover - msg = "`to_arrow` is not implemented for CuDF backend." - raise NotImplementedError(msg) + return self._native_series.to_arrow() import pyarrow as pa # ignore-banned-import() return pa.Array.from_pandas(self._native_series) + def mode(self: Self) -> Self: + native_series = self._native_series + result = native_series.mode() + result.name = native_series.name + return self._from_native_series(result) + @property def str(self) -> PandasLikeSeriesStringNamespace: return PandasLikeSeriesStringNamespace(self) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 64df5913f..1b91f0910 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -182,27 +182,45 @@ def sort( def join( self, other: Self, - *, + on: str | list[str] | None = None, how: Literal["inner", "left", "cross", "semi", "anti"] = "inner", + *, left_on: str | list[str] | None = None, right_on: str | list[str] | None = None, + suffix: str = "_right", ) -> Self: _supported_joins = ("inner", "left", "cross", "anti", "semi") if how not in _supported_joins: - msg = f"Only the following join stragies are supported: {_supported_joins}; found '{how}'." + msg = f"Only the following join strategies are supported: {_supported_joins}; found '{how}'." 
raise NotImplementedError(msg) - if how == "cross" and (left_on or right_on): - msg = "Can not pass left_on, right_on for cross join" + if how == "cross" and ( + left_on is not None or right_on is not None or on is not None + ): + msg = "Can not pass `left_on`, `right_on` or `on` keys for cross join" + raise ValueError(msg) + + if how != "cross" and (on is None and (left_on is None or right_on is None)): + msg = f"Either (`left_on` and `right_on`) or `on` keys should be specified for {how}." raise ValueError(msg) + if how != "cross" and ( + on is not None and (left_on is not None or right_on is not None) + ): + msg = f"If `on` is specified, `left_on` and `right_on` should be None for {how}." + raise ValueError(msg) + + if on is not None: + left_on = right_on = on + return self._from_compliant_dataframe( self._compliant_frame.join( self._extract_compliant(other), how=how, left_on=left_on, right_on=right_on, + suffix=suffix, ) ) @@ -214,6 +232,64 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: self._compliant_frame.gather_every(n=n, offset=offset) ) + def join_asof( + self, + other: Self, + *, + left_on: str | None = None, + right_on: str | None = None, + on: str | None = None, + by_left: str | list[str] | None = None, + by_right: str | list[str] | None = None, + by: str | list[str] | None = None, + strategy: Literal["backward", "forward", "nearest"] = "backward", + ) -> Self: + _supported_strategies = ("backward", "forward", "nearest") + + if strategy not in _supported_strategies: + msg = f"Only the following strategies are supported: {_supported_strategies}; found '{strategy}'." + raise NotImplementedError(msg) + + if (on is None) and (left_on is None or right_on is None): + msg = "Either (`left_on` and `right_on`) or `on` keys should be specified." + raise ValueError(msg) + if (on is not None) and (left_on is not None or right_on is not None): + msg = "If `on` is specified, `left_on` and `right_on` should be None." + raise ValueError(msg) + if (by is None) and ( + (by_left is None and by_right is not None) + or (by_left is not None and by_right is None) + ): + msg = ( + "Can not specify only `by_left` or `by_right`, you need to specify both." + ) + raise ValueError(msg) + if (by is not None) and (by_left is not None or by_right is not None): + msg = "If `by` is specified, `by_left` and `by_right` should be None." + raise ValueError(msg) + if on is not None: + return self._from_compliant_dataframe( + self._compliant_frame.join_asof( + self._extract_compliant(other), + on=on, + by_left=by_left, + by_right=by_right, + by=by, + strategy=strategy, + ) + ) + return self._from_compliant_dataframe( + self._compliant_frame.join_asof( + self._extract_compliant(other), + left_on=left_on, + right_on=right_on, + by_left=by_left, + by_right=by_right, + by=by, + strategy=strategy, + ) + ) + class DataFrame(BaseFrame[FrameT]): """ @@ -350,7 +426,7 @@ def to_pandas(self) -> pd.DataFrame: def write_csv(self, file: str | Path | BytesIO | None = None) -> Any: r""" - Write dataframe to parquet file. + Write dataframe to comma-separated values (CSV) file. Examples: Construct pandas and Polars DataFrames: @@ -522,17 +598,24 @@ def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ... @overload def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ... @overload + def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ... + @overload def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... 
# type: ignore[overload-overlap] @overload def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ... @overload + def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ... + @overload def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap] @overload def __getitem__(self, item: Sequence[int]) -> Self: ... @overload - def __getitem__(self, item: str) -> Series: ... + def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap] + + @overload + def __getitem__(self, item: Sequence[str]) -> Self: ... @overload def __getitem__(self, item: slice) -> Self: ... @@ -542,8 +625,9 @@ def __getitem__( item: str | slice | Sequence[int] + | Sequence[str] | tuple[Sequence[int], str | int] - | tuple[Sequence[int], Sequence[int] | Sequence[str] | slice], + | tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice], ) -> Series | Self: """ Extract column or slice of DataFrame. @@ -560,6 +644,12 @@ def __getitem__( a `Series`. - `df[[0, 1], [0, 1, 2]]` extracts the first two rows and the first three columns and returns a `DataFrame` + - `df[:, [0, 1, 2]]` extracts all rows from the first three columns and returns a + `DataFrame`. + - `df[:, ['a', 'c']]` extracts all rows and columns `'a'` and `'c'` and returns a + `DataFrame`. + - `df[['a', 'c']]` extracts all rows and columns `'a'` and `'c'` and returns a + `DataFrame`. - `df[0: 2, ['a', 'c']]` extracts the first two rows and columns `'a'` and `'c'` and returns a `DataFrame` - `df[:, 0: 2]` extracts all rows from the first two columns and returns a `DataFrame` @@ -1768,27 +1858,29 @@ def sort( def join( self, other: Self, - *, + on: str | list[str] | None = None, how: Literal["inner", "left", "cross", "semi", "anti"] = "inner", + *, left_on: str | list[str] | None = None, right_on: str | list[str] | None = None, + suffix: str = "_right", ) -> Self: r""" Join in SQL-like fashion. Arguments: - other: DataFrame to join with. - + other: Lazy DataFrame to join with. + on: Name(s) of the join columns in both DataFrames. If set, `left_on` and + `right_on` should be None. how: Join strategy. * *inner*: Returns rows that have matching values in both tables. * *cross*: Returns the Cartesian product of rows from both tables. * *semi*: Filter rows that have a match in the right table. * *anti*: Filter rows that do not have a match in the right table. - - left_on: Name(s) of the left join column(s). - - right_on: Name(s) of the right join column(s). + left_on: Join column of the left DataFrame. + right_on: Join column of the right DataFrame. + suffix: Suffix to append to columns with a duplicate name. Returns: A new joined DataFrame @@ -1837,7 +1929,195 @@ def join( │ 2 ┆ 7.0 ┆ b ┆ y │ └─────┴─────┴─────┴───────┘ """ - return super().join(other, how=how, left_on=left_on, right_on=right_on) + return super().join( + other, how=how, left_on=left_on, right_on=right_on, on=on, suffix=suffix + ) + + def join_asof( + self, + other: Self, + *, + left_on: str | None = None, + right_on: str | None = None, + on: str | None = None, + by_left: str | list[str] | None = None, + by_right: str | list[str] | None = None, + by: str | list[str] | None = None, + strategy: Literal["backward", "forward", "nearest"] = "backward", + ) -> Self: + """ + Perform an asof join. + + This is similar to a left-join except that we match on nearest key rather than equal keys. + + Both DataFrames must be sorted by the asof_join key. + + Arguments: + other: DataFrame to join with. 
+ + left_on: Name(s) of the left join column(s). + + right_on: Name(s) of the right join column(s). + + on: Join column of both DataFrames. If set, left_on and right_on should be None. + + by_left: join on these columns before doing asof join + + by_right: join on these columns before doing asof join + + by: join on these columns before doing asof join + + strategy: Join strategy. The default is "backward". + + * *backward*: selects the last row in the right DataFrame whose "on" key is less than or equal to the left's key. + * *forward*: selects the first row in the right DataFrame whose "on" key is greater than or equal to the left's key. + * *nearest*: search selects the last row in the right DataFrame whose value is nearest to the left's key. + + Returns: + A new joined DataFrame + + Examples: + >>> from datetime import datetime + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> data_gdp = { + ... "datetime": [ + ... datetime(2016, 1, 1), + ... datetime(2017, 1, 1), + ... datetime(2018, 1, 1), + ... datetime(2019, 1, 1), + ... datetime(2020, 1, 1), + ... ], + ... "gdp": [4164, 4411, 4566, 4696, 4827], + ... } + >>> data_population = { + ... "datetime": [ + ... datetime(2016, 3, 1), + ... datetime(2018, 8, 1), + ... datetime(2019, 1, 1), + ... ], + ... "population": [82.19, 82.66, 83.12], + ... } + >>> gdp_pd = pd.DataFrame(data_gdp) + >>> population_pd = pd.DataFrame(data_population) + + >>> gdp_pl = pl.DataFrame(data_gdp).sort("datetime") + >>> population_pl = pl.DataFrame(data_population).sort("datetime") + + Let's define a dataframe-agnostic function in which we join over "datetime" column: + + >>> @nw.narwhalify + ... def join_asof_datetime(df, other_any, strategy): + ... return df.join_asof(other_any, on="datetime", strategy=strategy) + + We can now pass either pandas or Polars to the function: + + >>> join_asof_datetime(population_pd, gdp_pd, strategy="backward") + datetime population gdp + 0 2016-03-01 82.19 4164 + 1 2018-08-01 82.66 4566 + 2 2019-01-01 83.12 4696 + + >>> join_asof_datetime(population_pl, gdp_pl, strategy="backward") + shape: (3, 3) + ┌─────────────────────┬────────────┬──────┐ + │ datetime ┆ population ┆ gdp │ + │ --- ┆ --- ┆ --- │ + │ datetime[μs] ┆ f64 ┆ i64 │ + ╞═════════════════════╪════════════╪══════╡ + │ 2016-03-01 00:00:00 ┆ 82.19 ┆ 4164 │ + │ 2018-08-01 00:00:00 ┆ 82.66 ┆ 4566 │ + │ 2019-01-01 00:00:00 ┆ 83.12 ┆ 4696 │ + └─────────────────────┴────────────┴──────┘ + + Here is a real-world times-series example that uses `by` argument. + + >>> from datetime import datetime + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> data_quotes = { + ... "datetime": [ + ... datetime(2016, 5, 25, 13, 30, 0, 23), + ... datetime(2016, 5, 25, 13, 30, 0, 23), + ... datetime(2016, 5, 25, 13, 30, 0, 30), + ... datetime(2016, 5, 25, 13, 30, 0, 41), + ... datetime(2016, 5, 25, 13, 30, 0, 48), + ... datetime(2016, 5, 25, 13, 30, 0, 49), + ... datetime(2016, 5, 25, 13, 30, 0, 72), + ... datetime(2016, 5, 25, 13, 30, 0, 75), + ... ], + ... "ticker": [ + ... "GOOG", + ... "MSFT", + ... "MSFT", + ... "MSFT", + ... "GOOG", + ... "AAPL", + ... "GOOG", + ... "MSFT", + ... ], + ... "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], + ... "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03], + ... } + >>> data_trades = { + ... "datetime": [ + ... datetime(2016, 5, 25, 13, 30, 0, 23), + ... datetime(2016, 5, 25, 13, 30, 0, 38), + ... datetime(2016, 5, 25, 13, 30, 0, 48), + ... 
datetime(2016, 5, 25, 13, 30, 0, 48), + ... datetime(2016, 5, 25, 13, 30, 0, 48), + ... ], + ... "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + ... "price": [51.95, 51.95, 720.77, 720.92, 98.0], + ... "quantity": [75, 155, 100, 100, 100], + ... } + >>> quotes_pd = pd.DataFrame(data_quotes) + >>> trades_pd = pd.DataFrame(data_trades) + >>> quotes_pl = pl.DataFrame(data_quotes).sort("datetime") + >>> trades_pl = pl.DataFrame(data_trades).sort("datetime") + + Let's define a dataframe-agnostic function in which we join over "datetime" and by "ticker" columns: + + >>> @nw.narwhalify + ... def join_asof_datetime_by_ticker(df, other_any): + ... return df.join_asof(other_any, on="datetime", by="ticker") + + We can now pass either pandas or Polars to the function: + + >>> join_asof_datetime_by_ticker(trades_pd, quotes_pd) + datetime ticker price quantity bid ask + 0 2016-05-25 13:30:00.000023 MSFT 51.95 75 51.95 51.96 + 1 2016-05-25 13:30:00.000038 MSFT 51.95 155 51.97 51.98 + 2 2016-05-25 13:30:00.000048 GOOG 720.77 100 720.50 720.93 + 3 2016-05-25 13:30:00.000048 GOOG 720.92 100 720.50 720.93 + 4 2016-05-25 13:30:00.000048 AAPL 98.00 100 NaN NaN + + >>> join_asof_datetime_by_ticker(trades_pl, quotes_pl) + shape: (5, 6) + ┌────────────────────────────┬────────┬────────┬──────────┬───────┬────────┐ + │ datetime ┆ ticker ┆ price ┆ quantity ┆ bid ┆ ask │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ datetime[μs] ┆ str ┆ f64 ┆ i64 ┆ f64 ┆ f64 │ + ╞════════════════════════════╪════════╪════════╪══════════╪═══════╪════════╡ + │ 2016-05-25 13:30:00.000023 ┆ MSFT ┆ 51.95 ┆ 75 ┆ 51.95 ┆ 51.96 │ + │ 2016-05-25 13:30:00.000038 ┆ MSFT ┆ 51.95 ┆ 155 ┆ 51.97 ┆ 51.98 │ + │ 2016-05-25 13:30:00.000048 ┆ GOOG ┆ 720.77 ┆ 100 ┆ 720.5 ┆ 720.93 │ + │ 2016-05-25 13:30:00.000048 ┆ GOOG ┆ 720.92 ┆ 100 ┆ 720.5 ┆ 720.93 │ + │ 2016-05-25 13:30:00.000048 ┆ AAPL ┆ 98.0 ┆ 100 ┆ null ┆ null │ + └────────────────────────────┴────────┴────────┴──────────┴───────┴────────┘ + """ + return super().join_asof( + other, + left_on=left_on, + right_on=right_on, + on=on, + by_left=by_left, + by_right=by_right, + by=by, + strategy=strategy, + ) # --- descriptive --- def is_duplicated(self: Self) -> Series: @@ -3307,27 +3587,29 @@ def sort( def join( self, other: Self, - *, + on: str | list[str] | None = None, how: Literal["inner", "left", "cross", "semi", "anti"] = "inner", + *, left_on: str | list[str] | None = None, right_on: str | list[str] | None = None, + suffix: str = "_right", ) -> Self: r""" Add a join operation to the Logical Plan. Arguments: other: Lazy DataFrame to join with. - + on: Name(s) of the join columns in both DataFrames. If set, `left_on` and + `right_on` should be None. how: Join strategy. * *inner*: Returns rows that have matching values in both tables. * *cross*: Returns the Cartesian product of rows from both tables. * *semi*: Filter rows that have a match in the right table. * *anti*: Filter rows that do not have a match in the right table. - left_on: Join column of the left DataFrame. - right_on: Join column of the right DataFrame. + suffix: Suffix to append to columns with a duplicate name. 
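# A further sketch (not in the docstrings above) covering `by_left`/`by_right`, for the case
# where the grouping column has a different name on each side; assumes the pandas backend
# dispatches to `pandas.merge_asof` with `left_by`/`right_by`.
import pandas as pd
import narwhals as nw

trades = nw.from_native(
    pd.DataFrame({"time": [1, 2, 3], "sym": ["A", "A", "B"], "price": [10.0, 11.0, 12.0]}),
    eager_only=True,
)
quotes = nw.from_native(
    pd.DataFrame({"time": [1, 3], "ticker": ["A", "B"], "bid": [9.9, 11.9]}),
    eager_only=True,
)

# Match each trade to the latest quote at or before its time, within the same symbol.
out = trades.join_asof(
    quotes, on="time", by_left="sym", by_right="ticker", strategy="backward"
)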
Returns: A new joined LazyFrame @@ -3376,7 +3658,194 @@ def join( │ 2 ┆ 7.0 ┆ b ┆ y │ └─────┴─────┴─────┴───────┘ """ - return super().join(other, how=how, left_on=left_on, right_on=right_on) + return super().join( + other, how=how, left_on=left_on, right_on=right_on, on=on, suffix=suffix + ) + + def join_asof( + self, + other: Self, + *, + left_on: str | None = None, + right_on: str | None = None, + on: str | None = None, + by_left: str | list[str] | None = None, + by_right: str | list[str] | None = None, + by: str | list[str] | None = None, + strategy: Literal["backward", "forward", "nearest"] = "backward", + ) -> Self: + """ + Perform an asof join. + + This is similar to a left-join except that we match on nearest key rather than equal keys. + + Both DataFrames must be sorted by the asof_join key. + + Arguments: + other: DataFrame to join with. + + left_on: Name(s) of the left join column(s). + + right_on: Name(s) of the right join column(s). + + on: Join column of both DataFrames. If set, left_on and right_on should be None. + + by_left: join on these columns before doing asof join + + by_right: join on these columns before doing asof join + + by: join on these columns before doing asof join + + strategy: Join strategy. The default is "backward". + + * *backward*: selects the last row in the right DataFrame whose "on" key is less than or equal to the left's key. + * *forward*: selects the first row in the right DataFrame whose "on" key is greater than or equal to the left's key. + * *nearest*: search selects the last row in the right DataFrame whose value is nearest to the left's key. + + Returns: + A new joined DataFrame + + Examples: + >>> from datetime import datetime + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> data_gdp = { + ... "datetime": [ + ... datetime(2016, 1, 1), + ... datetime(2017, 1, 1), + ... datetime(2018, 1, 1), + ... datetime(2019, 1, 1), + ... datetime(2020, 1, 1), + ... ], + ... "gdp": [4164, 4411, 4566, 4696, 4827], + ... } + >>> data_population = { + ... "datetime": [ + ... datetime(2016, 3, 1), + ... datetime(2018, 8, 1), + ... datetime(2019, 1, 1), + ... ], + ... "population": [82.19, 82.66, 83.12], + ... } + >>> gdp_pd = pd.DataFrame(data_gdp) + >>> population_pd = pd.DataFrame(data_population) + >>> gdp_pl = pl.LazyFrame(data_gdp).sort("datetime") + >>> population_pl = pl.LazyFrame(data_population).sort("datetime") + + Let's define a dataframe-agnostic function in which we join over "datetime" column: + + >>> @nw.narwhalify + ... def join_asof_datetime(df, other_any, strategy): + ... return df.join_asof(other_any, on="datetime", strategy=strategy) + + We can now pass either pandas or Polars to the function: + + >>> join_asof_datetime(population_pd, gdp_pd, strategy="backward") + datetime population gdp + 0 2016-03-01 82.19 4164 + 1 2018-08-01 82.66 4566 + 2 2019-01-01 83.12 4696 + + >>> join_asof_datetime(population_pl, gdp_pl, strategy="backward").collect() + shape: (3, 3) + ┌─────────────────────┬────────────┬──────┐ + │ datetime ┆ population ┆ gdp │ + │ --- ┆ --- ┆ --- │ + │ datetime[μs] ┆ f64 ┆ i64 │ + ╞═════════════════════╪════════════╪══════╡ + │ 2016-03-01 00:00:00 ┆ 82.19 ┆ 4164 │ + │ 2018-08-01 00:00:00 ┆ 82.66 ┆ 4566 │ + │ 2019-01-01 00:00:00 ┆ 83.12 ┆ 4696 │ + └─────────────────────┴────────────┴──────┘ + + Here is a real-world times-series example that uses `by` argument. 
+ + >>> from datetime import datetime + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> data_quotes = { + ... "datetime": [ + ... datetime(2016, 5, 25, 13, 30, 0, 23), + ... datetime(2016, 5, 25, 13, 30, 0, 23), + ... datetime(2016, 5, 25, 13, 30, 0, 30), + ... datetime(2016, 5, 25, 13, 30, 0, 41), + ... datetime(2016, 5, 25, 13, 30, 0, 48), + ... datetime(2016, 5, 25, 13, 30, 0, 49), + ... datetime(2016, 5, 25, 13, 30, 0, 72), + ... datetime(2016, 5, 25, 13, 30, 0, 75), + ... ], + ... "ticker": [ + ... "GOOG", + ... "MSFT", + ... "MSFT", + ... "MSFT", + ... "GOOG", + ... "AAPL", + ... "GOOG", + ... "MSFT", + ... ], + ... "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], + ... "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03], + ... } + >>> data_trades = { + ... "datetime": [ + ... datetime(2016, 5, 25, 13, 30, 0, 23), + ... datetime(2016, 5, 25, 13, 30, 0, 38), + ... datetime(2016, 5, 25, 13, 30, 0, 48), + ... datetime(2016, 5, 25, 13, 30, 0, 48), + ... datetime(2016, 5, 25, 13, 30, 0, 48), + ... ], + ... "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + ... "price": [51.95, 51.95, 720.77, 720.92, 98.0], + ... "quantity": [75, 155, 100, 100, 100], + ... } + >>> quotes_pd = pd.DataFrame(data_quotes) + >>> trades_pd = pd.DataFrame(data_trades) + >>> quotes_pl = pl.LazyFrame(data_quotes).sort("datetime") + >>> trades_pl = pl.LazyFrame(data_trades).sort("datetime") + + Let's define a dataframe-agnostic function in which we join over "datetime" and by "ticker" columns: + + >>> @nw.narwhalify + ... def join_asof_datetime_by_ticker(df, other_any): + ... return df.join_asof(other_any, on="datetime", by="ticker") + + We can now pass either pandas or Polars to the function: + + >>> join_asof_datetime_by_ticker(trades_pd, quotes_pd) + datetime ticker price quantity bid ask + 0 2016-05-25 13:30:00.000023 MSFT 51.95 75 51.95 51.96 + 1 2016-05-25 13:30:00.000038 MSFT 51.95 155 51.97 51.98 + 2 2016-05-25 13:30:00.000048 GOOG 720.77 100 720.50 720.93 + 3 2016-05-25 13:30:00.000048 GOOG 720.92 100 720.50 720.93 + 4 2016-05-25 13:30:00.000048 AAPL 98.00 100 NaN NaN + + >>> join_asof_datetime_by_ticker(trades_pl, quotes_pl).collect() + shape: (5, 6) + ┌────────────────────────────┬────────┬────────┬──────────┬───────┬────────┐ + │ datetime ┆ ticker ┆ price ┆ quantity ┆ bid ┆ ask │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ datetime[μs] ┆ str ┆ f64 ┆ i64 ┆ f64 ┆ f64 │ + ╞════════════════════════════╪════════╪════════╪══════════╪═══════╪════════╡ + │ 2016-05-25 13:30:00.000023 ┆ MSFT ┆ 51.95 ┆ 75 ┆ 51.95 ┆ 51.96 │ + │ 2016-05-25 13:30:00.000038 ┆ MSFT ┆ 51.95 ┆ 155 ┆ 51.97 ┆ 51.98 │ + │ 2016-05-25 13:30:00.000048 ┆ GOOG ┆ 720.77 ┆ 100 ┆ 720.5 ┆ 720.93 │ + │ 2016-05-25 13:30:00.000048 ┆ GOOG ┆ 720.92 ┆ 100 ┆ 720.5 ┆ 720.93 │ + │ 2016-05-25 13:30:00.000048 ┆ AAPL ┆ 98.0 ┆ 100 ┆ null ┆ null │ + └────────────────────────────┴────────┴────────┴──────────┴───────┴────────┘ + """ + return super().join_asof( + other, + left_on=left_on, + right_on=right_on, + on=on, + by_left=by_left, + by_right=by_right, + by=by, + strategy=strategy, + ) def clone(self) -> Self: r""" diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 66516eac9..2cd9f0983 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -16,6 +16,8 @@ from typing_extensions import TypeGuard import cudf import dask.dataframe as dd + import duckdb + import ibis import modin.pandas as mpd import pandas as pd import polars as pl @@ -64,74 +66,96 @@ def 
get_dask_dataframe() -> Any: return sys.modules.get("dask.dataframe", None) +def get_duckdb() -> Any: + """Get duckdb module (if already imported - else return None).""" + return sys.modules.get("duckdb", None) + + def get_dask_expr() -> Any: """Get dask_expr module (if already imported - else return None).""" return sys.modules.get("dask_expr", None) +def get_ibis() -> Any: + """Get ibis module (if already imported - else return None).""" + return sys.modules.get("ibis", None) + + def is_pandas_dataframe(df: Any) -> TypeGuard[pd.DataFrame]: """Check whether `df` is a pandas DataFrame without importing pandas.""" - return bool((pd := get_pandas()) is not None and isinstance(df, pd.DataFrame)) + return (pd := get_pandas()) is not None and isinstance(df, pd.DataFrame) def is_pandas_series(ser: Any) -> TypeGuard[pd.Series[Any]]: """Check whether `ser` is a pandas Series without importing pandas.""" - return bool((pd := get_pandas()) is not None and isinstance(ser, pd.Series)) + return (pd := get_pandas()) is not None and isinstance(ser, pd.Series) def is_modin_dataframe(df: Any) -> TypeGuard[mpd.DataFrame]: """Check whether `df` is a modin DataFrame without importing modin.""" - return bool((pd := get_modin()) is not None and isinstance(df, pd.DataFrame)) + return (pd := get_modin()) is not None and isinstance(df, pd.DataFrame) def is_modin_series(ser: Any) -> TypeGuard[mpd.Series]: """Check whether `ser` is a modin Series without importing modin.""" - return bool((pd := get_modin()) is not None and isinstance(ser, pd.Series)) + return (pd := get_modin()) is not None and isinstance(ser, pd.Series) def is_cudf_dataframe(df: Any) -> TypeGuard[cudf.DataFrame]: """Check whether `df` is a cudf DataFrame without importing cudf.""" - return bool((pd := get_cudf()) is not None and isinstance(df, pd.DataFrame)) + return (pd := get_cudf()) is not None and isinstance(df, pd.DataFrame) def is_cudf_series(ser: Any) -> TypeGuard[pd.Series[Any]]: """Check whether `ser` is a cudf Series without importing cudf.""" - return bool((pd := get_cudf()) is not None and isinstance(ser, pd.Series)) + return (pd := get_cudf()) is not None and isinstance(ser, pd.Series) def is_dask_dataframe(df: Any) -> TypeGuard[dd.DataFrame]: """Check whether `df` is a Dask DataFrame without importing Dask.""" - return bool((dd := get_dask_dataframe()) is not None and isinstance(df, dd.DataFrame)) + return (dd := get_dask_dataframe()) is not None and isinstance(df, dd.DataFrame) + + +def is_duckdb_relation(df: Any) -> TypeGuard[duckdb.DuckDBPyRelation]: + """Check whether `df` is a DuckDB Relation without importing DuckDB.""" + return (duckdb := get_duckdb()) is not None and isinstance( + df, duckdb.DuckDBPyRelation + ) + + +def is_ibis_table(df: Any) -> TypeGuard[ibis.Table]: + """Check whether `df` is a Ibis Table without importing Ibis.""" + return (ibis := get_ibis()) is not None and isinstance(df, ibis.expr.types.Table) def is_polars_dataframe(df: Any) -> TypeGuard[pl.DataFrame]: """Check whether `df` is a Polars DataFrame without importing Polars.""" - return bool((pl := get_polars()) is not None and isinstance(df, pl.DataFrame)) + return (pl := get_polars()) is not None and isinstance(df, pl.DataFrame) def is_polars_lazyframe(df: Any) -> TypeGuard[pl.LazyFrame]: """Check whether `df` is a Polars LazyFrame without importing Polars.""" - return bool((pl := get_polars()) is not None and isinstance(df, pl.LazyFrame)) + return (pl := get_polars()) is not None and isinstance(df, pl.LazyFrame) def is_polars_series(ser: Any) -> 
TypeGuard[pl.Series]: """Check whether `ser` is a Polars Series without importing Polars.""" - return bool((pl := get_polars()) is not None and isinstance(ser, pl.Series)) + return (pl := get_polars()) is not None and isinstance(ser, pl.Series) def is_pyarrow_chunked_array(ser: Any) -> TypeGuard[pa.ChunkedArray]: """Check whether `ser` is a PyArrow ChunkedArray without importing PyArrow.""" - return bool((pa := get_pyarrow()) is not None and isinstance(ser, pa.ChunkedArray)) + return (pa := get_pyarrow()) is not None and isinstance(ser, pa.ChunkedArray) def is_pyarrow_table(df: Any) -> TypeGuard[pa.Table]: """Check whether `df` is a PyArrow Table without importing PyArrow.""" - return bool((pa := get_pyarrow()) is not None and isinstance(df, pa.Table)) + return (pa := get_pyarrow()) is not None and isinstance(df, pa.Table) def is_numpy_array(arr: Any) -> TypeGuard[np.ndarray]: """Check whether `arr` is a NumPy Array without importing NumPy.""" - return bool((np := get_numpy()) is not None and isinstance(arr, np.ndarray)) + return (np := get_numpy()) is not None and isinstance(arr, np.ndarray) def is_pandas_like_dataframe(df: Any) -> bool: @@ -159,6 +183,8 @@ def is_pandas_like_series(arr: Any) -> bool: "get_cudf", "get_pyarrow", "get_numpy", + "get_ibis", + "is_ibis_table", "is_pandas_dataframe", "is_pandas_series", "is_polars_dataframe", diff --git a/narwhals/expr.py b/narwhals/expr.py index d8acadd20..5c5ff7d2e 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1585,9 +1585,8 @@ def head(self, n: int = 10) -> Self: r""" Get the first `n` rows. - Arguments - n : int - Number of rows to return. + Arguments: + n: Number of rows to return. Examples: >>> import narwhals as nw @@ -1628,9 +1627,8 @@ def tail(self, n: int = 10) -> Self: r""" Get the last `n` rows. - Arguments - n : int - Number of rows to return. + Arguments: + n: Number of rows to return. Examples: >>> import narwhals as nw @@ -1819,7 +1817,7 @@ def clip( >>> import pandas as pd >>> import polars as pl >>> import narwhals as nw - >>> + >>> s = [1, 2, 3] >>> df_pd = pd.DataFrame({"s": s}) >>> df_pl = pl.DataFrame({"s": s}) @@ -1913,6 +1911,49 @@ def clip( """ return self.__class__(lambda plx: self._call(plx).clip(lower_bound, upper_bound)) + def mode(self: Self) -> Self: + r"""Compute the most occurring value(s). + + Can return multiple values. + + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + + >>> data = { + ... "a": [1, 1, 2, 3], + ... "b": [1, 1, 2, 2], + ... } + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(df): + ... return df.select(nw.col("a", "b").mode()).sort("a", "b") + + We can then pass either pandas or Polars to `func`: + + >>> func(df_pd) + a b + 0 1 1 + 1 1 2 + + >>> func(df_pl) + shape: (2, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 1 │ + │ 1 ┆ 2 │ + └─────┴─────┘ + """ + return self.__class__(lambda plx: self._call(plx).mode()) + @property def str(self: Self) -> ExprStringNamespace: return ExprStringNamespace(self) diff --git a/narwhals/functions.py b/narwhals/functions.py index 51193c6c0..430705e66 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -33,6 +33,105 @@ def concat( *, how: Literal["horizontal", "vertical"] = "vertical", ) -> FrameT: + """ + Concatenate multiple DataFrames, LazyFrames into a single entity. + + Arguments: + items: DataFrames, LazyFrames to concatenate. 
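# Minimal sketch of how the new dependency helpers in narwhals/dependencies.py above are
# intended to be used: they only look at sys.modules, so nothing is imported as a side effect.
from narwhals.dependencies import is_duckdb_relation, is_ibis_table

def describe_native(obj: object) -> str:
    # Each check returns False unless the corresponding library has already been imported
    # by the caller and `obj` is an instance of its frame type.
    if is_duckdb_relation(obj):
        return "duckdb relation"
    if is_ibis_table(obj):
        return "ibis table"
    return "unrecognised object"

print(describe_native("not a frame"))  # -> "unrecognised object"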
+ + how: {'vertical', 'horizontal'} + * vertical: Stacks Series from DataFrames vertically and fills with `null` + if the lengths don't match. + * horizontal: Stacks Series from DataFrames horizontally and fills with `null` + if the lengths don't match. + + Returns: + A new DataFrame, Lazyframe resulting from the concatenation. + + Raises: + NotImplementedError: The items to concatenate should either all be eager, or all lazy + + Examples: + + Let's take an example of vertical concatenation: + + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + >>> data_1 = {"a": [1, 2, 3], "b": [4, 5, 6]} + >>> data_2 = {"a": [5, 2], "b": [1, 4]} + + >>> df_pd_1 = pd.DataFrame(data_1) + >>> df_pd_2 = pd.DataFrame(data_2) + >>> df_pl_1 = pl.DataFrame(data_1) + >>> df_pl_2 = pl.DataFrame(data_2) + + Let's define a dataframe-agnostic function: + + >>> @nw.narwhalify + ... def func(df1, df2): + ... return nw.concat([df1, df2], how="vertical") + + >>> func(df_pd_1, df_pd_2) + a b + 0 1 4 + 1 2 5 + 2 3 6 + 0 5 1 + 1 2 4 + >>> func(df_pl_1, df_pl_2) + shape: (5, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 4 │ + │ 2 ┆ 5 │ + │ 3 ┆ 6 │ + │ 5 ┆ 1 │ + │ 2 ┆ 4 │ + └─────┴─────┘ + + Let's look at case a for horizontal concatenation: + + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + >>> data_1 = {"a": [1, 2, 3], "b": [4, 5, 6]} + >>> data_2 = {"c": [5, 2], "d": [1, 4]} + + >>> df_pd_1 = pd.DataFrame(data_1) + >>> df_pd_2 = pd.DataFrame(data_2) + >>> df_pl_1 = pl.DataFrame(data_1) + >>> df_pl_2 = pl.DataFrame(data_2) + + Defining a dataframe-agnostic function: + + >>> @nw.narwhalify + ... def func(df1, df2): + ... return nw.concat([df1, df2], how="horizontal") + + >>> func(df_pd_1, df_pd_2) + a b c d + 0 1 4 5.0 1.0 + 1 2 5 2.0 4.0 + 2 3 6 NaN NaN + + >>> func(df_pl_1, df_pl_2) + shape: (3, 4) + ┌─────┬─────┬──────┬──────┐ + │ a ┆ b ┆ c ┆ d │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪══════╪══════╡ + │ 1 ┆ 4 ┆ 5 ┆ 1 │ + │ 2 ┆ 5 ┆ 2 ┆ 4 │ + │ 3 ┆ 6 ┆ null ┆ null │ + └─────┴─────┴──────┴──────┘ + + """ + if how not in ("horizontal", "vertical"): # pragma: no cover msg = "Only horizontal and vertical concatenations are supported" raise NotImplementedError(msg) @@ -167,13 +266,13 @@ def from_dict( >>> import narwhals as nw >>> data = {"a": [1, 2, 3], "b": [4, 5, 6]} - Let's define a dataframe-agnostic function: + Let's create a new dataframe of the same class as the dataframe we started with, from a dict of new data: >>> @nw.narwhalify ... def func(df): - ... data = {"c": [5, 2], "d": [1, 4]} + ... new_data = {"c": [5, 2], "d": [1, 4]} ... native_namespace = nw.get_native_namespace(df) - ... return nw.from_dict(data, native_namespace=native_namespace) + ... return nw.from_dict(new_data, native_namespace=native_namespace) Let's see what happens when passing pandas / Polars input: diff --git a/narwhals/series.py b/narwhals/series.py index d80564d22..9fcb07a23 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -85,6 +85,58 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ca = pa.chunked_array([self.to_arrow()]) return ca.__arrow_c_stream__(requested_schema=requested_schema) + def scatter(self, indices: int | Sequence[int], values: Any) -> Self: + """ + Set value(s) at given position(s). + + Arguments: + indices: Position(s) to set items at. + values: Values to set. 
+ + Warning: + For some libraries (pandas, Polars), this method operates in-place, + whereas for others (PyArrow) it doesn't! + We recommend being careful with it, and not relying on the + in-placeness. For example, a valid use case is when updating + a column in an eager dataframe, see the example below. + + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + >>> data = {"a": [1, 2, 3], "b": [4, 5, 6]} + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(df): + ... return df.with_columns(df["a"].scatter([0, 1], [999, 888])) + + We can then pass either pandas or Polars to `func`: + + >>> func(df_pd) + a b + 0 999 4 + 1 888 5 + 2 3 6 + >>> func(df_pl) + shape: (3, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 999 ┆ 4 │ + │ 888 ┆ 5 │ + │ 3 ┆ 6 │ + └─────┴─────┘ + """ + return self._from_compliant_series( + self._compliant_series.scatter(indices, self._extract_native(values)) + ) + @property def shape(self) -> tuple[int]: """ @@ -783,12 +835,9 @@ def drop_nulls(self) -> Self: """ Drop all null values. - See Also: - drop_nans - Notes: - A null value is not the same as a NaN value. - To drop NaN values, use :func:`drop_nans`. + pandas and Polars handle null values differently. Polars distinguishes + between NaN and Null, whereas pandas doesn't. Examples: >>> import pandas as pd @@ -2006,9 +2055,8 @@ def head(self: Self, n: int = 10) -> Self: r""" Get the first `n` rows. - Arguments - n : int - Number of rows to return. + Arguments: + n: Number of rows to return. Examples: >>> import narwhals as nw @@ -2047,9 +2095,8 @@ def tail(self: Self, n: int = 10) -> Self: r""" Get the last `n` rows. - Arguments - n : int - Number of rows to return. + Arguments: + n: Number of rows to return. Examples: >>> import narwhals as nw @@ -2087,7 +2134,7 @@ def round(self: Self, decimals: int = 0) -> Self: r""" Round underlying floating point data by `decimals` digits. - Arguments + Arguments: decimals: Number of decimals to round by. Notes: @@ -2137,7 +2184,7 @@ def to_dummies( r""" Get dummy/indicator variables. - Arguments + Arguments: separator: Separator/delimiter used when generating column names. drop_first: Remove the first category from the variable being encoded. @@ -2281,6 +2328,44 @@ def to_arrow(self: Self) -> pa.Array: """ return self._compliant_series.to_arrow() + def mode(self: Self) -> Self: + r""" + Compute the most occurring value(s). + + Can return multiple values. + + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + + >>> data = [1, 1, 2, 2, 3] + >>> s_pd = pd.Series(name="a", data=data) + >>> s_pl = pl.Series(name="a", values=data) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(s): + ... 
return s.mode().sort() + + We can then pass either pandas or Polars to `func`: + + >>> func(s_pd) + 0 1 + 1 2 + Name: a, dtype: int64 + + >>> func(s_pl) # doctest:+NORMALIZE_WHITESPACE + shape: (2,) + Series: 'a' [i64] + [ + 1 + 2 + ] + """ + return self._from_compliant_series(self._compliant_series.mode()) + @property def str(self) -> SeriesStringNamespace: return SeriesStringNamespace(self) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 8363b36e9..0720980c1 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -40,7 +40,6 @@ from narwhals.expr import Then as NwThen from narwhals.expr import When as NwWhen from narwhals.expr import when as nw_when -from narwhals.functions import concat from narwhals.functions import show_versions from narwhals.schema import Schema as NwSchema from narwhals.series import Series as NwSeries @@ -80,20 +79,25 @@ class DataFrame(NwDataFrame[IntoDataFrameT]): def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ... @overload def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ... + @overload + def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ... @overload def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap] @overload def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ... + @overload + def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ... @overload def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap] @overload def __getitem__(self, item: Sequence[int]) -> Self: ... - @overload - def __getitem__(self, item: str) -> Series: ... + def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap] + @overload + def __getitem__(self, item: Sequence[str]) -> Self: ... @overload def __getitem__(self, item: slice) -> Self: ... @@ -572,26 +576,50 @@ def from_native( @overload def from_native( - native_dataframe: IntoDataFrameT | T, + native_dataframe: IntoDataFrameT, + *, + strict: Literal[False], + eager_only: None = ..., + eager_or_interchange_only: Literal[True], + series_only: None = ..., + allow_series: None = ..., +) -> DataFrame[IntoDataFrameT]: ... + + +@overload +def from_native( + native_dataframe: T, *, strict: Literal[False], eager_only: None = ..., eager_or_interchange_only: Literal[True], series_only: None = ..., allow_series: None = ..., -) -> DataFrame[IntoDataFrameT] | T: ... +) -> T: ... @overload def from_native( - native_dataframe: IntoDataFrameT | T, + native_dataframe: IntoDataFrameT, *, strict: Literal[False], eager_only: Literal[True], eager_or_interchange_only: None = ..., series_only: None = ..., allow_series: None = ..., -) -> DataFrame[IntoDataFrameT] | T: ... +) -> DataFrame[IntoDataFrameT]: ... + + +@overload +def from_native( + native_dataframe: T, + *, + strict: Literal[False], + eager_only: Literal[True], + eager_or_interchange_only: None = ..., + series_only: None = ..., + allow_series: None = ..., +) -> T: ... @overload @@ -620,14 +648,26 @@ def from_native( @overload def from_native( - native_dataframe: IntoFrameT | T, + native_dataframe: IntoFrameT, *, strict: Literal[False], eager_only: None = ..., eager_or_interchange_only: None = ..., series_only: None = ..., allow_series: None = ..., -) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | T: ... +) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT]: ... 
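# Sketch of the behaviour the split `strict=False` overloads above describe: a supported
# native frame is wrapped, anything else is handed back unchanged instead of raising.
# (Assumes pandas is installed; any other supported backend behaves the same way.)
import pandas as pd
import narwhals.stable.v1 as nw

wrapped = nw.from_native(pd.DataFrame({"a": [1, 2]}), strict=False)   # -> nw.DataFrame
untouched = nw.from_native("definitely not a dataframe", strict=False)
assert untouched == "definitely not a dataframe"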
+ + +@overload +def from_native( + native_dataframe: T, + *, + strict: Literal[False], + eager_only: None = ..., + eager_or_interchange_only: None = ..., + series_only: None = ..., + allow_series: None = ..., +) -> T: ... @overload @@ -1375,6 +1415,128 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return _stableify(nw.mean_horizontal(*exprs)) +@overload +def concat( + items: Iterable[DataFrame[Any]], + *, + how: Literal["horizontal", "vertical"] = "vertical", +) -> DataFrame[Any]: ... + + +@overload +def concat( + items: Iterable[LazyFrame[Any]], + *, + how: Literal["horizontal", "vertical"] = "vertical", +) -> LazyFrame[Any]: ... + + +def concat( + items: Iterable[DataFrame[Any] | LazyFrame[Any]], + *, + how: Literal["horizontal", "vertical"] = "vertical", +) -> DataFrame[Any] | LazyFrame[Any]: + """ + Concatenate multiple DataFrames, LazyFrames into a single entity. + + Arguments: + items: DataFrames, LazyFrames to concatenate. + + how: {'vertical', 'horizontal'} + * vertical: Stacks Series from DataFrames vertically and fills with `null` + if the lengths don't match. + * horizontal: Stacks Series from DataFrames horizontally and fills with `null` + if the lengths don't match. + + Returns: + A new DataFrame, Lazyframe resulting from the concatenation. + + Raises: + NotImplementedError: The items to concatenate should either all be eager, or all lazy + + Examples: + + Let's take an example of vertical concatenation: + + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals.stable.v1 as nw + >>> data_1 = {"a": [1, 2, 3], "b": [4, 5, 6]} + >>> data_2 = {"a": [5, 2], "b": [1, 4]} + + >>> df_pd_1 = pd.DataFrame(data_1) + >>> df_pd_2 = pd.DataFrame(data_2) + >>> df_pl_1 = pl.DataFrame(data_1) + >>> df_pl_2 = pl.DataFrame(data_2) + + Let's define a dataframe-agnostic function: + + >>> @nw.narwhalify + ... def func(df1, df2): + ... return nw.concat([df1, df2], how="vertical") + + >>> func(df_pd_1, df_pd_2) + a b + 0 1 4 + 1 2 5 + 2 3 6 + 0 5 1 + 1 2 4 + >>> func(df_pl_1, df_pl_2) + shape: (5, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 4 │ + │ 2 ┆ 5 │ + │ 3 ┆ 6 │ + │ 5 ┆ 1 │ + │ 2 ┆ 4 │ + └─────┴─────┘ + + Let's look at case a for horizontal concatenation: + + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals.stable.v1 as nw + >>> data_1 = {"a": [1, 2, 3], "b": [4, 5, 6]} + >>> data_2 = {"c": [5, 2], "d": [1, 4]} + + >>> df_pd_1 = pd.DataFrame(data_1) + >>> df_pd_2 = pd.DataFrame(data_2) + >>> df_pl_1 = pl.DataFrame(data_1) + >>> df_pl_2 = pl.DataFrame(data_2) + + Defining a dataframe-agnostic function: + + >>> @nw.narwhalify + ... def func(df1, df2): + ... return nw.concat([df1, df2], how="horizontal") + + >>> func(df_pd_1, df_pd_2) + a b c d + 0 1 4 5.0 1.0 + 1 2 5 2.0 4.0 + 2 3 6 NaN NaN + + >>> func(df_pl_1, df_pl_2) + shape: (3, 4) + ┌─────┬─────┬──────┬──────┐ + │ a ┆ b ┆ c ┆ d │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪══════╪══════╡ + │ 1 ┆ 4 ┆ 5 ┆ 1 │ + │ 2 ┆ 5 ┆ 2 ┆ 4 │ + │ 3 ┆ 6 ┆ null ┆ null │ + └─────┴─────┴──────┴──────┘ + + """ + return _stableify(nw.concat(items, how=how)) # type: ignore[no-any-return] + + def is_ordered_categorical(series: Series) -> bool: """ Return whether indices of categories are semantically meaningful. 
@@ -1417,7 +1579,7 @@ def is_ordered_categorical(series: Series) -> bool: def maybe_align_index(lhs: T, rhs: Series | DataFrame[Any] | LazyFrame[Any]) -> T: """ - Align `lhs` to the Index of `rhs, if they're both pandas-like. + Align `lhs` to the Index of `rhs`, if they're both pandas-like. Notes: This is only really intended for backwards-compatibility purposes, @@ -1725,13 +1887,13 @@ def from_dict( >>> import narwhals.stable.v1 as nw >>> data = {"a": [1, 2, 3], "b": [4, 5, 6]} - Let's define a dataframe-agnostic function: + Let's create a new dataframe of the same class as the dataframe we started with, from a dict of new data: >>> @nw.narwhalify ... def func(df): - ... data = {"c": [5, 2], "d": [1, 4]} + ... new_data = {"c": [5, 2], "d": [1, 4]} ... native_namespace = nw.get_native_namespace(df) - ... return nw.from_dict(data, native_namespace=native_namespace) + ... return nw.from_dict(new_data, native_namespace=native_namespace) Let's see what happens when passing pandas / Polars input: diff --git a/narwhals/translate.py b/narwhals/translate.py index c19ea7192..69a99ea2b 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -18,6 +18,8 @@ from narwhals.dependencies import is_cudf_dataframe from narwhals.dependencies import is_cudf_series from narwhals.dependencies import is_dask_dataframe +from narwhals.dependencies import is_duckdb_relation +from narwhals.dependencies import is_ibis_table from narwhals.dependencies import is_modin_dataframe from narwhals.dependencies import is_modin_series from narwhals.dependencies import is_pandas_dataframe @@ -107,26 +109,50 @@ def from_native( @overload def from_native( - native_object: IntoDataFrameT | T, + native_object: IntoDataFrameT, + *, + strict: Literal[False], + eager_only: None = ..., + eager_or_interchange_only: Literal[True], + series_only: None = ..., + allow_series: None = ..., +) -> DataFrame[IntoDataFrameT]: ... + + +@overload +def from_native( + native_object: T, *, strict: Literal[False], eager_only: None = ..., eager_or_interchange_only: Literal[True], series_only: None = ..., allow_series: None = ..., -) -> DataFrame[IntoDataFrameT] | T: ... +) -> T: ... @overload def from_native( - native_object: IntoDataFrameT | T, + native_object: IntoDataFrameT, *, strict: Literal[False], eager_only: Literal[True], eager_or_interchange_only: None = ..., series_only: None = ..., allow_series: None = ..., -) -> DataFrame[IntoDataFrameT] | T: ... +) -> DataFrame[IntoDataFrameT]: ... + + +@overload +def from_native( + native_object: T, + *, + strict: Literal[False], + eager_only: Literal[True], + eager_or_interchange_only: None = ..., + series_only: None = ..., + allow_series: None = ..., +) -> T: ... @overload @@ -155,14 +181,26 @@ def from_native( @overload def from_native( - native_object: IntoFrameT | T, + native_object: IntoFrameT, + *, + strict: Literal[False], + eager_only: None = ..., + eager_or_interchange_only: None = ..., + series_only: None = ..., + allow_series: None = ..., +) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT]: ... + + +@overload +def from_native( + native_object: T, *, strict: Literal[False], eager_only: None = ..., eager_or_interchange_only: None = ..., series_only: None = ..., allow_series: None = ..., -) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | T: ... +) -> T: ... 
@overload @@ -176,8 +214,8 @@ def from_native( allow_series: None = ..., ) -> DataFrame[IntoDataFrameT]: """ - from_native(df, strict=True, eager_or_interchange_only=True, allow_series=True) - from_native(df, eager_or_interchange_only=True, allow_series=True) + from_native(df, strict=True, eager_or_interchange_only=True) + from_native(df, eager_or_interchange_only=True) """ @@ -192,8 +230,8 @@ def from_native( allow_series: None = ..., ) -> DataFrame[IntoDataFrameT]: """ - from_native(df, strict=True, eager_only=True, allow_series=True) - from_native(df, eager_only=True, allow_series=True) + from_native(df, strict=True, eager_only=True) + from_native(df, eager_only=True) """ @@ -208,8 +246,8 @@ def from_native( allow_series: Literal[True], ) -> DataFrame[Any] | LazyFrame[Any] | Series: """ - from_native(df, strict=True, eager_only=True) - from_native(df, eager_only=True) + from_native(df, strict=True, allow_series=True) + from_native(df, allow_series=True) """ @@ -295,6 +333,8 @@ def from_native( # noqa: PLR0915 from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.series import ArrowSeries from narwhals._dask.dataframe import DaskLazyFrame + from narwhals._duckdb.dataframe import DuckDBInterchangeFrame + from narwhals._ibis.dataframe import IbisInterchangeFrame from narwhals._interchange.dataframe import InterchangeFrame from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.series import PandasLikeSeries @@ -510,6 +550,32 @@ def from_native( # noqa: PLR0915 level="full", ) + # DuckDB + elif is_duckdb_relation(native_object): + if eager_only or series_only: # pragma: no cover + msg = ( + "Cannot only use `series_only=True` or `eager_only=False` " + "with DuckDB Relation" + ) + raise TypeError(msg) + return DataFrame( + DuckDBInterchangeFrame(native_object), + level="interchange", + ) + + # Ibis + elif is_ibis_table(native_object): # pragma: no cover + if eager_only or series_only: + msg = ( + "Cannot only use `series_only=True` or `eager_only=False` " + "with Ibis table" + ) + raise TypeError(msg) + return DataFrame( + IbisInterchangeFrame(native_object), + level="interchange", + ) + # Interchange protocol elif hasattr(native_object, "__dataframe__"): if eager_only or series_only: @@ -519,7 +585,7 @@ def from_native( # noqa: PLR0915 ) raise TypeError(msg) return DataFrame( - InterchangeFrame(native_object.__dataframe__()), + InterchangeFrame(native_object), level="interchange", ) diff --git a/narwhals/utils.py b/narwhals/utils.py index 6c1b5c1b4..ec3c722d4 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -153,7 +153,7 @@ def validate_laziness(items: Iterable[Any]) -> None: def maybe_align_index(lhs: T, rhs: Series | BaseFrame[Any]) -> T: """ - Align `lhs` to the Index of `rhs, if they're both pandas-like. + Align `lhs` to the Index of `rhs`, if they're both pandas-like. 
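# Sketch (assuming duckdb is installed) of the interchange-level support wired into
# `from_native` above: a DuckDB relation is wrapped as a metadata-only frame, so schema
# inspection works while eager/lazy operations remain unavailable.
import duckdb
import narwhals as nw

rel = duckdb.sql("SELECT 1 AS a, 'x' AS b")
df = nw.from_native(rel, eager_or_interchange_only=True)
print(df.schema)  # column names mapped to narwhals dtypes, read from the relation's types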
Notes: This is only really intended for backwards-compatibility purposes, diff --git a/pyproject.toml b/pyproject.toml index 2aa6c8005..5ec7fef5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "narwhals" -version = "1.5.5" +version = "1.7.0" authors = [ { name="Marco Gorelli", email="33491632+MarcoGorelli@users.noreply.github.com" }, ] @@ -23,6 +23,12 @@ exclude = [ "/docs", "/tests", "/tpch", + "/utils", + ".gitignore", + "CONTRIBUTING.md", + "mkdocs.yml", + "noxfile.py", + "requirements-dev.txt", ] [project.optional-dependencies] @@ -70,7 +76,18 @@ lint.ignore = [ [tool.ruff.lint.per-file-ignores] "tests/*" = ["S101"] +"tpch/tests/*" = ["S101"] "utils/*" = ["S311", "PTH123"] +"tpch/execute/*" = ["T201"] +"tpch/notebooks/*" = [ + "ANN001", + "ANN201", + "EM101", + "EXE002", + "PTH123", + "T203", + "TRY003", +] [tool.ruff.lint.pydocstyle] convention = "google" @@ -108,7 +125,11 @@ env = [ plugins = ["covdefaults"] [tool.coverage.report] -omit = ['narwhals/typing.py'] +omit = [ + 'narwhals/typing.py', + # we can run this in every environment that we measure coverage on due to upper-bound constraits + 'narwhals/_ibis/*', +] exclude_also = [ "> POLARS_VERSION", "if sys.version_info() <", diff --git a/requirements-dev.txt b/requirements-dev.txt index 2424d4ea1..23ff1757e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,6 @@ +tqdm covdefaults +duckdb pandas polars pre-commit diff --git a/tests/conftest.py b/tests/conftest.py index cdf4e0be6..011b83265 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -72,9 +72,14 @@ def polars_lazy_constructor(obj: Any) -> pl.LazyFrame: return pl.LazyFrame(obj) -def dask_lazy_constructor(obj: Any) -> IntoFrame: # pragma: no cover +def dask_lazy_p1_constructor(obj: Any) -> IntoFrame: # pragma: no cover dd = get_dask_dataframe() - return dd.from_pandas(pd.DataFrame(obj)) # type: ignore[no-any-return] + return dd.from_dict(obj, npartitions=1) # type: ignore[no-any-return] + + +def dask_lazy_p2_constructor(obj: Any) -> IntoFrame: # pragma: no cover + dd = get_dask_dataframe() + return dd.from_dict(obj, npartitions=2) # type: ignore[no-any-return] def pyarrow_table_constructor(obj: Any) -> IntoDataFrame: @@ -98,7 +103,7 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame: if get_cudf() is not None: eager_constructors.append(cudf_constructor) # pragma: no cover if get_dask_dataframe() is not None: # pragma: no cover - lazy_constructors.append(dask_lazy_constructor) # type: ignore # noqa: PGH003 + lazy_constructors.extend([dask_lazy_p1_constructor, dask_lazy_p2_constructor]) # type: ignore # noqa: PGH003 @pytest.fixture(params=eager_constructors) diff --git a/tests/expr_and_series/arithmetic_test.py b/tests/expr_and_series/arithmetic_test.py index 47d3e8ff0..7ff945c80 100644 --- a/tests/expr_and_series/arithmetic_test.py +++ b/tests/expr_and_series/arithmetic_test.py @@ -149,7 +149,7 @@ def test_truediv_same_dims(constructor_eager: Any, request: Any) -> None: compare_dicts({"a": result}, {"a": [2, 1, 1 / 3]}) -@pytest.mark.slow() +@pytest.mark.slow @given( # type: ignore[misc] left=st.integers(-100, 100), right=st.integers(-100, 100), @@ -189,7 +189,7 @@ def test_floordiv(left: int, right: int) -> None: compare_dicts(result, expected) -@pytest.mark.slow() +@pytest.mark.slow @given( # type: ignore[misc] left=st.integers(-100, 100), right=st.integers(-100, 100), diff --git a/tests/expr_and_series/dt/datetime_attributes_test.py 
b/tests/expr_and_series/dt/datetime_attributes_test.py index 4d59567df..22e20590e 100644 --- a/tests/expr_and_series/dt/datetime_attributes_test.py +++ b/tests/expr_and_series/dt/datetime_attributes_test.py @@ -42,6 +42,8 @@ def test_datetime_attributes( and "pyarrow" not in str(constructor) ): request.applymarker(pytest.mark.xfail) + if attribute == "date" and "cudf" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(getattr(nw.col("a").dt, attribute)()) @@ -73,6 +75,8 @@ def test_datetime_attributes_series( and "pyarrow" not in str(constructor_eager) ): request.applymarker(pytest.mark.xfail) + if attribute == "date" and "cudf" in str(constructor_eager): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor_eager(data), eager_only=True) result = df.select(getattr(df["a"].dt, attribute)()) @@ -82,6 +86,8 @@ def test_datetime_attributes_series( def test_datetime_chained_attributes(request: Any, constructor_eager: Any) -> None: if "pandas" in str(constructor_eager) and "pyarrow" not in str(constructor_eager): request.applymarker(pytest.mark.xfail) + if "cudf" in str(constructor_eager): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor_eager(data), eager_only=True) result = df.select(df["a"].dt.date().dt.year()) diff --git a/tests/expr_and_series/dt/ordinal_day_test.py b/tests/expr_and_series/dt/ordinal_day_test.py index 1cb464259..2681188df 100644 --- a/tests/expr_and_series/dt/ordinal_day_test.py +++ b/tests/expr_and_series/dt/ordinal_day_test.py @@ -17,7 +17,7 @@ parse_version(pd.__version__) < parse_version("2.0.0"), reason="pyarrow dtype not available", ) -@pytest.mark.slow() +@pytest.mark.slow def test_ordinal_day(dates: datetime) -> None: result_pd = nw.from_native(pd.Series([dates]), series_only=True).dt.ordinal_day()[0] result_pdms = nw.from_native( diff --git a/tests/expr_and_series/dt/total_minutes_test.py b/tests/expr_and_series/dt/total_minutes_test.py index f2469e495..bcd664442 100644 --- a/tests/expr_and_series/dt/total_minutes_test.py +++ b/tests/expr_and_series/dt/total_minutes_test.py @@ -22,7 +22,7 @@ parse_version(pd.__version__) < parse_version("2.2.0"), reason="pyarrow dtype not available", ) -@pytest.mark.slow() +@pytest.mark.slow def test_total_minutes(timedeltas: timedelta) -> None: result_pd = nw.from_native( pd.Series([timedeltas]), series_only=True diff --git a/tests/expr_and_series/is_duplicated_test.py b/tests/expr_and_series/is_duplicated_test.py index 5fa060312..71d165749 100644 --- a/tests/expr_and_series/is_duplicated_test.py +++ b/tests/expr_and_series/is_duplicated_test.py @@ -3,19 +3,13 @@ import narwhals.stable.v1 as nw from tests.utils import compare_dicts -data = { - "a": [1, 1, 2], - "b": [1, 2, 3], -} +data = {"a": [1, 1, 2], "b": [1, 2, 3], "index": [0, 1, 2]} def test_is_duplicated_expr(constructor: Any) -> None: df = nw.from_native(constructor(data)) - result = df.select(nw.all().is_duplicated()) - expected = { - "a": [True, True, False], - "b": [False, False, False], - } + result = df.select(nw.col("a", "b").is_duplicated(), "index").sort("index") + expected = {"a": [True, True, False], "b": [False, False, False], "index": [0, 1, 2]} compare_dicts(result, expected) diff --git a/tests/expr_and_series/is_unique_test.py b/tests/expr_and_series/is_unique_test.py index 8bddbb647..d203c1635 100644 --- a/tests/expr_and_series/is_unique_test.py +++ b/tests/expr_and_series/is_unique_test.py @@ -6,15 +6,17 @@ data = { "a": [1, 1, 2], "b": [1, 2, 
3], + "index": [0, 1, 2], } def test_is_unique_expr(constructor: Any) -> None: df = nw.from_native(constructor(data)) - result = df.select(nw.all().is_unique()) + result = df.select(nw.col("a", "b").is_unique(), "index").sort("index") expected = { "a": [False, False, True], "b": [True, True, True], + "index": [0, 1, 2], } compare_dicts(result, expected) diff --git a/tests/expr_and_series/mode_test.py b/tests/expr_and_series/mode_test.py new file mode 100644 index 000000000..33a0bef5a --- /dev/null +++ b/tests/expr_and_series/mode_test.py @@ -0,0 +1,37 @@ +from typing import Any + +import pytest + +import narwhals.stable.v1 as nw +from tests.utils import compare_dicts + +data = { + "a": [1, 1, 2, 2, 3], + "b": [1, 2, 3, 3, 4], +} + + +def test_mode_single_expr(constructor: Any, request: Any) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.select(nw.col("a").mode()).sort("a") + expected = {"a": [1, 2]} + compare_dicts(result, expected) + + +def test_mode_multi_expr(constructor: Any, request: Any) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) + result = df.select(nw.col("a", "b").mode()).sort("a", "b") + expected = {"a": [1, 2], "b": [3, 3]} + compare_dicts(result, expected) + + +def test_mode_series(constructor_eager: Any) -> None: + series = nw.from_native(constructor_eager(data), eager_only=True)["a"] + result = series.mode().sort() + expected = {"a": [1, 2]} + compare_dicts({"a": result}, expected) diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index fb01a3cfd..17b07cc1e 100644 --- a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext as does_not_raise from typing import Any import pytest @@ -14,26 +15,48 @@ def test_over_single(constructor: Any) -> None: df = nw.from_native(constructor(data)) - result = df.with_columns(c_max=nw.col("c").max().over("a")) expected = { "a": ["a", "a", "b", "b", "b"], "b": [1, 2, 3, 5, 3], "c": [5, 4, 3, 2, 1], "c_max": [5, 5, 3, 3, 3], } - compare_dicts(result, expected) + + context = ( + pytest.raises( + NotImplementedError, + match="`Expr.over` is not supported for Dask backend with multiple partitions.", + ) + if "dask_lazy_p2" in str(constructor) + else does_not_raise() + ) + + with context: + result = df.with_columns(c_max=nw.col("c").max().over("a")) + compare_dicts(result, expected) def test_over_multiple(constructor: Any) -> None: df = nw.from_native(constructor(data)) - result = df.with_columns(c_min=nw.col("c").min().over("a", "b")) expected = { "a": ["a", "a", "b", "b", "b"], "b": [1, 2, 3, 5, 3], "c": [5, 4, 3, 2, 1], "c_min": [5, 4, 1, 2, 1], } - compare_dicts(result, expected) + + context = ( + pytest.raises( + NotImplementedError, + match="`Expr.over` is not supported for Dask backend with multiple partitions.", + ) + if "dask_lazy_p2" in str(constructor) + else does_not_raise() + ) + + with context: + result = df.with_columns(c_min=nw.col("c").min().over("a", "b")) + compare_dicts(result, expected) def test_over_invalid(request: Any, constructor: Any) -> None: diff --git a/tests/expr_and_series/quantile_test.py b/tests/expr_and_series/quantile_test.py index d9064541f..5b8ff9334 100644 --- a/tests/expr_and_series/quantile_test.py +++ b/tests/expr_and_series/quantile_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import nullcontext as 
does_not_raise from typing import Any from typing import Literal @@ -28,12 +29,24 @@ def test_quantile_expr( ) -> None: if "dask" in str(constructor) and interpolation != "linear": request.applymarker(pytest.mark.xfail) + q = 0.3 data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_raw = constructor(data) df = nw.from_native(df_raw) - result = df.select(nw.all().quantile(quantile=q, interpolation=interpolation)) - compare_dicts(result, expected) + + context = ( + pytest.raises( + NotImplementedError, + match="`Expr.quantile` is not supported for Dask backend with multiple partitions.", + ) + if "dask_lazy_p2" in str(constructor) + else does_not_raise() + ) + + with context: + result = df.select(nw.all().quantile(quantile=q, interpolation=interpolation)) + compare_dicts(result, expected) @pytest.mark.parametrize( diff --git a/tests/frame/arrow_c_stream_test.py b/tests/frame/arrow_c_stream_test.py index 7a3403f69..cb856adf9 100644 --- a/tests/frame/arrow_c_stream_test.py +++ b/tests/frame/arrow_c_stream_test.py @@ -10,6 +10,9 @@ @pytest.mark.skipif( parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" ) +@pytest.mark.skipif( + parse_version(pa.__version__) < (16, 0, 0), reason="too old for pycapsule in PyArrow" +) def test_arrow_c_stream_test() -> None: df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True) result = pa.table(df) @@ -20,6 +23,9 @@ def test_arrow_c_stream_test() -> None: @pytest.mark.skipif( parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" ) +@pytest.mark.skipif( + parse_version(pa.__version__) < (16, 0, 0), reason="too old for pycapsule in PyArrow" +) def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None: # "poison" the dunder method to make sure it actually got called above monkeypatch.setattr( @@ -33,6 +39,9 @@ def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None: @pytest.mark.skipif( parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" ) +@pytest.mark.skipif( + parse_version(pa.__version__) < (16, 0, 0), reason="too old for pycapsule in PyArrow" +) def test_arrow_c_stream_test_fallback(monkeypatch: pytest.MonkeyPatch) -> None: # Check that fallback to PyArrow works monkeypatch.delattr("polars.DataFrame.__arrow_c_stream__") diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 970220bf2..a52759128 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -6,14 +6,12 @@ from tests.utils import compare_dicts -def test_concat_horizontal(constructor: Any, request: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_concat_horizontal(constructor: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df_left = nw.from_native(constructor(data)) + df_left = nw.from_native(constructor(data)).lazy() data_right = {"c": [6, 12, -1], "d": [0, -4, 2]} - df_right = nw.from_native(constructor(data_right)) + df_right = nw.from_native(constructor(data_right)).lazy() result = nw.concat([df_left, df_right], how="horizontal") expected = { @@ -29,12 +27,10 @@ def test_concat_horizontal(constructor: Any, request: Any) -> None: nw.concat([]) -def test_concat_vertical(constructor: Any, request: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_concat_vertical(constructor: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_left = ( - 
nw.from_native(constructor(data)).rename({"a": "c", "b": "d"}).drop("z").lazy() + nw.from_native(constructor(data)).lazy().rename({"a": "c", "b": "d"}).drop("z") ) data_right = {"c": [6, 12, -1], "d": [0, -4, 2]} diff --git a/tests/frame/filter_test.py b/tests/frame/filter_test.py index a8d3144aa..e7a289feb 100644 --- a/tests/frame/filter_test.py +++ b/tests/frame/filter_test.py @@ -1,5 +1,8 @@ +from contextlib import nullcontext as does_not_raise from typing import Any +import pytest + import narwhals.stable.v1 as nw from tests.utils import compare_dicts @@ -15,6 +18,17 @@ def test_filter(constructor: Any) -> None: def test_filter_with_boolean_list(constructor: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) - result = df.filter([False, True, True]) - expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} - compare_dicts(result, expected) + + context = ( + pytest.raises( + NotImplementedError, + match="`LazyFrame.filter` is not supported for Dask backend with boolean masks.", + ) + if "dask" in str(constructor) + else does_not_raise() + ) + + with context: + result = df.filter([False, True, True]) + expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} + compare_dicts(result, expected) diff --git a/tests/frame/interchange_schema_test.py b/tests/frame/interchange_schema_test.py index df901522a..afec06831 100644 --- a/tests/frame/interchange_schema_test.py +++ b/tests/frame/interchange_schema_test.py @@ -1,9 +1,14 @@ from datetime import date +from datetime import datetime +from datetime import timedelta +import duckdb +import pandas as pd import polars as pl import pytest import narwhals.stable.v1 as nw +from narwhals.utils import parse_version def test_interchange_schema() -> None: @@ -63,6 +68,158 @@ def test_interchange_schema() -> None: assert df["a"].dtype == nw.Int64 +@pytest.mark.filterwarnings("ignore:.*locale specific date formats") +def test_interchange_schema_ibis( + tmpdir: pytest.TempdirFactory, +) -> None: # pragma: no cover + ibis = pytest.importorskip("ibis") + df_pl = pl.DataFrame( + { + "a": [1, 1, 2], + "b": [4, 5, 6], + "c": [4, 5, 6], + "d": [4, 5, 6], + "e": [4, 5, 6], + "f": [4, 5, 6], + "g": [4, 5, 6], + "h": [4, 5, 6], + "i": [4, 5, 6], + "j": [4, 5, 6], + "k": ["fdafsd", "fdas", "ad"], + "l": ["fdafsd", "fdas", "ad"], + "m": [date(2021, 1, 1), date(2021, 1, 1), date(2021, 1, 1)], + "n": [datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 1)], + "o": [True, True, False], + }, + schema={ + "a": pl.Int64, + "b": pl.Int32, + "c": pl.Int16, + "d": pl.Int8, + "e": pl.UInt64, + "f": pl.UInt32, + "g": pl.UInt16, + "h": pl.UInt8, + "i": pl.Float64, + "j": pl.Float32, + "k": pl.String, + "l": pl.Categorical, + "m": pl.Date, + "n": pl.Datetime, + "o": pl.Boolean, + }, + ) + filepath = str(tmpdir / "file.parquet") # type: ignore[operator] + df_pl.write_parquet(filepath) + tbl = ibis.read_parquet(filepath) + df = nw.from_native(tbl, eager_or_interchange_only=True) + result = df.schema + if parse_version(ibis.__version__) > (6, 0, 0): + expected = { + "a": nw.Int64, + "b": nw.Int32, + "c": nw.Int16, + "d": nw.Int8, + "e": nw.UInt64, + "f": nw.UInt32, + "g": nw.UInt16, + "h": nw.UInt8, + "i": nw.Float64, + "j": nw.Float32, + "k": nw.String, + "l": nw.String, + "m": nw.Date, + "n": nw.Datetime, + "o": nw.Boolean, + } + else: + # Old versions of Ibis would read the file in + # with different data types + expected = { + "a": nw.Int64, + "b": nw.Int32, + "c": nw.Int16, + "d": nw.Int32, + "e": nw.Int32, + "f": 
nw.Int32, + "g": nw.Int32, + "h": nw.Int32, + "i": nw.Float64, + "j": nw.Float64, + "k": nw.String, + "l": nw.String, + "m": nw.Date, + "n": nw.Datetime, + "o": nw.Boolean, + } + assert result == expected + assert df["a"].dtype == nw.Int64 + + +def test_interchange_schema_duckdb() -> None: + df_pl = pl.DataFrame( # noqa: F841 + { + "a": [1, 1, 2], + "b": [4, 5, 6], + "c": [4, 5, 6], + "d": [4, 5, 6], + "e": [4, 5, 6], + "f": [4, 5, 6], + "g": [4, 5, 6], + "h": [4, 5, 6], + "i": [4, 5, 6], + "j": [4, 5, 6], + "k": ["fdafsd", "fdas", "ad"], + "l": ["fdafsd", "fdas", "ad"], + "m": [date(2021, 1, 1), date(2021, 1, 1), date(2021, 1, 1)], + "n": [datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 1)], + "o": [timedelta(1)] * 3, + "p": [True, True, False], + }, + schema={ + "a": pl.Int64, + "b": pl.Int32, + "c": pl.Int16, + "d": pl.Int8, + "e": pl.UInt64, + "f": pl.UInt32, + "g": pl.UInt16, + "h": pl.UInt8, + "i": pl.Float64, + "j": pl.Float32, + "k": pl.String, + "l": pl.Categorical, + "m": pl.Date, + "n": pl.Datetime, + "o": pl.Duration, + "p": pl.Boolean, + }, + ) + rel = duckdb.sql("select * from df_pl") + df = nw.from_native(rel, eager_or_interchange_only=True) + result = df.schema + expected = { + "a": nw.Int64, + "b": nw.Int32, + "c": nw.Int16, + "d": nw.Int8, + "e": nw.UInt64, + "f": nw.UInt32, + "g": nw.UInt16, + "h": nw.UInt8, + "i": nw.Float64, + "j": nw.Float32, + "k": nw.String, + "l": nw.String, + "m": nw.Date, + "n": nw.Datetime, + "o": nw.Duration, + "p": nw.Boolean, + } + assert result == expected + assert df["a"].dtype == nw.Int64 + + def test_invalid() -> None: df = pl.DataFrame({"a": [1, 2, 3]}).__dataframe__() with pytest.raises( @@ -82,3 +239,10 @@ def test_get_level() -> None: nw.get_level(nw.from_native(df.__dataframe__(), eager_or_interchange_only=True)) == "interchange" ) + + +def test_unknown_dtype() -> None: + df = pd.DataFrame({"a": [1, 2, 3]}) + rel = duckdb.from_df(df).select("cast(a as int128) as a") + result = nw.from_native(rel).schema + assert result == {"a": nw.Unknown} diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index d5c88ee4c..18e9aae64 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -1,68 +1,138 @@ from __future__ import annotations import re +from datetime import datetime from typing import Any +from typing import Literal import pandas as pd import pytest import narwhals.stable.v1 as nw from narwhals.utils import Implementation +from narwhals.utils import parse_version from tests.utils import compare_dicts def test_inner_join_two_keys(constructor: Any) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + data = { + "antananarivo": [1, 3, 2], + "bob": [4, 4, 6], + "zorro": [7.0, 8, 9], + "index": [0, 1, 2], + } df = nw.from_native(constructor(data)) df_right = df - result = df.join(df_right, left_on=["a", "b"], right_on=["a", "b"], how="inner") # type: ignore[arg-type] + result = df.join( + df_right, # type: ignore[arg-type] + left_on=["antananarivo", "bob"], + right_on=["antananarivo", "bob"], + how="inner", + ) + result_on = df.join(df_right, on=["antananarivo", "bob"], how="inner") # type: ignore[arg-type] + result = result.sort("index").drop("index_right") + result_on = result_on.sort("index").drop("index_right") expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8, 9], - "z_right": [7.0, 8, 9], + "antananarivo": [1, 3, 2], + "bob": [4, 4, 6], + "zorro": [7.0, 8, 9], + "zorro_right": [7.0, 8, 9], + "index": [0, 1, 2], } compare_dicts(result, expected) + 
compare_dicts(result_on, expected) def test_inner_join_single_key(constructor: Any) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + data = { + "antananarivo": [1, 3, 2], + "bob": [4, 4, 6], + "zorro": [7.0, 8, 9], + "index": [0, 1, 2], + } df = nw.from_native(constructor(data)) df_right = df - result = df.join(df_right, left_on="a", right_on="a", how="inner") # type: ignore[arg-type] + result = df.join( + df_right, # type: ignore[arg-type] + left_on="antananarivo", + right_on="antananarivo", + how="inner", + ).sort("index") + result_on = df.join(df_right, on="antananarivo", how="inner").sort("index") # type: ignore[arg-type] + result = result.drop("index_right") + result_on = result_on.drop("index_right") expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "b_right": [4, 4, 6], - "z": [7.0, 8, 9], - "z_right": [7.0, 8, 9], + "antananarivo": [1, 3, 2], + "bob": [4, 4, 6], + "bob_right": [4, 4, 6], + "zorro": [7.0, 8, 9], + "zorro_right": [7.0, 8, 9], + "index": [0, 1, 2], } compare_dicts(result, expected) + compare_dicts(result_on, expected) def test_cross_join(constructor: Any) -> None: - data = {"a": [1, 3, 2]} + data = {"antananarivo": [1, 3, 2]} df = nw.from_native(constructor(data)) - result = df.join(df, how="cross").sort("a", "a_right") # type: ignore[arg-type] + result = df.join(df, how="cross").sort("antananarivo", "antananarivo_right") # type: ignore[arg-type] expected = { - "a": [1, 1, 1, 2, 2, 2, 3, 3, 3], - "a_right": [1, 2, 3, 1, 2, 3, 1, 2, 3], + "antananarivo": [1, 1, 1, 2, 2, 2, 3, 3, 3], + "antananarivo_right": [1, 2, 3, 1, 2, 3, 1, 2, 3], } compare_dicts(result, expected) - with pytest.raises(ValueError, match="Can not pass left_on, right_on for cross join"): - df.join(df, how="cross", left_on="a") # type: ignore[arg-type] + with pytest.raises( + ValueError, match="Can not pass `left_on`, `right_on` or `on` keys for cross join" + ): + df.join(df, how="cross", left_on="antananarivo") # type: ignore[arg-type] + + +@pytest.mark.parametrize("how", ["inner", "left"]) +@pytest.mark.parametrize("suffix", ["_right", "_custom_suffix"]) +def test_suffix(constructor: Any, how: str, suffix: str) -> None: + data = { + "antananarivo": [1, 3, 2], + "bob": [4, 4, 6], + "zorro": [7.0, 8, 9], + } + df = nw.from_native(constructor(data)) + df_right = df + result = df.join( + df_right, # type: ignore[arg-type] + left_on=["antananarivo", "bob"], + right_on=["antananarivo", "bob"], + how=how, # type: ignore[arg-type] + suffix=suffix, + ) + result_cols = result.collect_schema().names() + assert result_cols == ["antananarivo", "bob", "zorro", f"zorro{suffix}"] + + +@pytest.mark.parametrize("suffix", ["_right", "_custom_suffix"]) +def test_cross_join_suffix(constructor: Any, suffix: str) -> None: + data = {"antananarivo": [1, 3, 2]} + df = nw.from_native(constructor(data)) + result = df.join(df, how="cross", suffix=suffix).sort( # type: ignore[arg-type] + "antananarivo", f"antananarivo{suffix}" + ) + expected = { + "antananarivo": [1, 1, 1, 2, 2, 2, 3, 3, 3], + f"antananarivo{suffix}": [1, 2, 3, 1, 2, 3, 1, 2, 3], + } + compare_dicts(result, expected) def test_cross_join_non_pandas() -> None: - data = {"a": [1, 3, 2]} + data = {"antananarivo": [1, 3, 2]} df = nw.from_native(pd.DataFrame(data)) # HACK to force testing for a non-pandas codepath df._compliant_frame._implementation = Implementation.MODIN result = df.join(df, how="cross") # type: ignore[arg-type] expected = { - "a": [1, 1, 1, 3, 3, 3, 2, 2, 2], - "a_right": [1, 3, 2, 1, 3, 2, 1, 3, 2], + "antananarivo": [1, 1, 1, 
3, 3, 3, 2, 2, 2], + "antananarivo_right": [1, 3, 2, 1, 3, 2, 1, 3, 2], } compare_dicts(result, expected) @@ -70,9 +140,17 @@ def test_cross_join_non_pandas() -> None: @pytest.mark.parametrize( ("join_key", "filter_expr", "expected"), [ - (["a", "b"], (nw.col("b") < 5), {"a": [2], "b": [6], "z": [9]}), - (["b"], (nw.col("b") < 5), {"a": [2], "b": [6], "z": [9]}), - (["b"], (nw.col("b") > 5), {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]}), + ( + ["antananarivo", "bob"], + (nw.col("bob") < 5), + {"antananarivo": [2], "bob": [6], "zorro": [9]}, + ), + (["bob"], (nw.col("bob") < 5), {"antananarivo": [2], "bob": [6], "zorro": [9]}), + ( + ["bob"], + (nw.col("bob") > 5), + {"antananarivo": [1, 3], "bob": [4, 4], "zorro": [7.0, 8.0]}, + ), ], ) def test_anti_join( @@ -81,7 +159,7 @@ def test_anti_join( filter_expr: nw.Expr, expected: dict[str, list[Any]], ) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) other = df.filter(filter_expr) result = df.join(other, how="anti", left_on=join_key, right_on=join_key) # type: ignore[arg-type] @@ -91,9 +169,21 @@ def test_anti_join( @pytest.mark.parametrize( ("join_key", "filter_expr", "expected"), [ - (["a"], (nw.col("b") > 5), {"a": [2], "b": [6], "z": [9]}), - (["b"], (nw.col("b") < 5), {"a": [1, 3], "b": [4, 4], "z": [7, 8]}), - (["a", "b"], (nw.col("b") < 5), {"a": [1, 3], "b": [4, 4], "z": [7, 8]}), + ( + ["antananarivo"], + (nw.col("bob") > 5), + {"antananarivo": [2], "bob": [6], "zorro": [9]}, + ), + ( + ["bob"], + (nw.col("bob") < 5), + {"antananarivo": [1, 3], "bob": [4, 4], "zorro": [7, 8]}, + ), + ( + ["antananarivo", "bob"], + (nw.col("bob") < 5), + {"antananarivo": [1, 3], "bob": [4, 4], "zorro": [7, 8]}, + ), ], ) def test_semi_join( @@ -102,74 +192,400 @@ def test_semi_join( filter_expr: nw.Expr, expected: dict[str, list[Any]], ) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) other = df.filter(filter_expr) - result = df.join(other, how="semi", left_on=join_key, right_on=join_key) # type: ignore[arg-type] + result = df.join(other, how="semi", left_on=join_key, right_on=join_key).sort( # type: ignore[arg-type] + "antananarivo" + ) compare_dicts(result, expected) @pytest.mark.parametrize("how", ["right", "full"]) def test_join_not_implemented(constructor: Any, how: str) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) with pytest.raises( NotImplementedError, match=re.escape( - f"Only the following join stragies are supported: ('inner', 'left', 'cross', 'anti', 'semi'); found '{how}'." + f"Only the following join strategies are supported: ('inner', 'left', 'cross', 'anti', 'semi'); found '{how}'." 
), ): - df.join(df, left_on="a", right_on="a", how=how) # type: ignore[arg-type] + df.join(df, left_on="antananarivo", right_on="antananarivo", how=how) # type: ignore[arg-type] @pytest.mark.filterwarnings("ignore:the default coalesce behavior") def test_left_join(constructor: Any) -> None: - data_left = {"a": [1.0, 2, 3], "b": [4.0, 5, 6]} - data_right = {"a": [1.0, 2, 3], "c": [4.0, 5, 7]} + data_left = { + "antananarivo": [1.0, 2, 3], + "bob": [4.0, 5, 6], + "index": [0.0, 1.0, 2.0], + } + data_right = {"antananarivo": [1.0, 2, 3], "c": [4.0, 5, 7], "index": [0.0, 1.0, 2.0]} df_left = nw.from_native(constructor(data_left)) df_right = nw.from_native(constructor(data_right)) - result = df_left.join(df_right, left_on="b", right_on="c", how="left").select( # type: ignore[arg-type] + result = df_left.join(df_right, left_on="bob", right_on="c", how="left").select( # type: ignore[arg-type] nw.all().fill_null(float("nan")) ) - expected = {"a": [1, 2, 3], "b": [4, 5, 6], "a_right": [1, 2, float("nan")]} + result = result.sort("index") + result = result.drop("index_right") + expected = { + "antananarivo": [1, 2, 3], + "bob": [4, 5, 6], + "antananarivo_right": [1, 2, float("nan")], + "index": [0, 1, 2], + } compare_dicts(result, expected) @pytest.mark.filterwarnings("ignore: the default coalesce behavior") def test_left_join_multiple_column(constructor: Any) -> None: - data_left = {"a": [1, 2, 3], "b": [4, 5, 6]} - data_right = {"a": [1, 2, 3], "c": [4, 5, 6]} + data_left = {"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "index": [0, 1, 2]} + data_right = {"antananarivo": [1, 2, 3], "c": [4, 5, 6], "index": [0, 1, 2]} df_left = nw.from_native(constructor(data_left)) df_right = nw.from_native(constructor(data_right)) - result = df_left.join(df_right, left_on=["a", "b"], right_on=["a", "c"], how="left") # type: ignore[arg-type] - expected = {"a": [1, 2, 3], "b": [4, 5, 6]} + result = df_left.join( + df_right, # type: ignore[arg-type] + left_on=["antananarivo", "bob"], + right_on=["antananarivo", "c"], + how="left", + ) + result = result.sort("index") + result = result.drop("index_right") + expected = {"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "index": [0, 1, 2]} compare_dicts(result, expected) @pytest.mark.filterwarnings("ignore: the default coalesce behavior") def test_left_join_overlapping_column(constructor: Any) -> None: - data_left = {"a": [1.0, 2, 3], "b": [4.0, 5, 6], "d": [1.0, 4, 2]} - data_right = {"a": [1.0, 2, 3], "c": [4.0, 5, 6], "d": [1.0, 4, 2]} + data_left = { + "antananarivo": [1.0, 2, 3], + "bob": [4.0, 5, 6], + "d": [1.0, 4, 2], + "index": [0.0, 1.0, 2.0], + } + data_right = { + "antananarivo": [1.0, 2, 3], + "c": [4.0, 5, 6], + "d": [1.0, 4, 2], + "index": [0.0, 1.0, 2.0], + } df_left = nw.from_native(constructor(data_left)) df_right = nw.from_native(constructor(data_right)) - result = df_left.join(df_right, left_on="b", right_on="c", how="left") # type: ignore[arg-type] + result = df_left.join(df_right, left_on="bob", right_on="c", how="left").sort("index") # type: ignore[arg-type] + result = result.drop("index_right") expected: dict[str, list[Any]] = { - "a": [1, 2, 3], - "b": [4, 5, 6], + "antananarivo": [1, 2, 3], + "bob": [4, 5, 6], "d": [1, 4, 2], - "a_right": [1, 2, 3], + "antananarivo_right": [1, 2, 3], "d_right": [1, 4, 2], + "index": [0, 1, 2], } compare_dicts(result, expected) - result = df_left.join(df_right, left_on="a", right_on="d", how="left").select( # type: ignore[arg-type] - nw.all().fill_null(float("nan")) - ) + result = df_left.join( + df_right, # type: 
ignore[arg-type] + left_on="antananarivo", + right_on="d", + how="left", + ).select(nw.all().fill_null(float("nan"))) + result = result.sort("index") + result = result.drop("index_right") expected = { - "a": [1, 2, 3], - "b": [4, 5, 6], + "antananarivo": [1, 2, 3], + "bob": [4, 5, 6], "d": [1, 4, 2], - "a_right": [1.0, 3.0, float("nan")], + "antananarivo_right": [1.0, 3.0, float("nan")], "c": [4.0, 6.0, float("nan")], + "index": [0, 1, 2], } compare_dicts(result, expected) + + +@pytest.mark.parametrize("how", ["inner", "left", "semi", "anti"]) +def test_join_keys_exceptions(constructor: Any, how: str) -> None: + data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} + df = nw.from_native(constructor(data)) + + with pytest.raises( + ValueError, + match=rf"Either \(`left_on` and `right_on`\) or `on` keys should be specified for {how}.", + ): + df.join(df, how=how) # type: ignore[arg-type] + with pytest.raises( + ValueError, + match=rf"Either \(`left_on` and `right_on`\) or `on` keys should be specified for {how}.", + ): + df.join(df, how=how, left_on="antananarivo") # type: ignore[arg-type] + with pytest.raises( + ValueError, + match=rf"Either \(`left_on` and `right_on`\) or `on` keys should be specified for {how}.", + ): + df.join(df, how=how, right_on="antananarivo") # type: ignore[arg-type] + with pytest.raises( + ValueError, + match=f"If `on` is specified, `left_on` and `right_on` should be None for {how}.", + ): + df.join(df, how=how, on="antananarivo", right_on="antananarivo") # type: ignore[arg-type] + + +def test_joinasof_numeric(constructor: Any, request: Any) -> None: + if "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + if parse_version(pd.__version__) < (2, 1) and ( + ("pandas_pyarrow" in str(constructor)) or ("pandas_nullable" in str(constructor)) + ): + request.applymarker(pytest.mark.xfail) + df = nw.from_native( + constructor({"antananarivo": [1, 5, 10], "val": ["a", "b", "c"]}) + ).sort("antananarivo") + df_right = nw.from_native( + constructor({"antananarivo": [1, 2, 3, 6, 7], "val": [1, 2, 3, 6, 7]}) + ).sort("antananarivo") + result_backward = df.join_asof( + df_right, # type: ignore[arg-type] + left_on="antananarivo", + right_on="antananarivo", + ) + result_forward = df.join_asof( + df_right, # type: ignore[arg-type] + left_on="antananarivo", + right_on="antananarivo", + strategy="forward", + ) + result_nearest = df.join_asof( + df_right, # type: ignore[arg-type] + left_on="antananarivo", + right_on="antananarivo", + strategy="nearest", + ) + result_backward_on = df.join_asof(df_right, on="antananarivo") # type: ignore[arg-type] + result_forward_on = df.join_asof(df_right, on="antananarivo", strategy="forward") # type: ignore[arg-type] + result_nearest_on = df.join_asof(df_right, on="antananarivo", strategy="nearest") # type: ignore[arg-type] + expected_backward = { + "antananarivo": [1, 5, 10], + "val": ["a", "b", "c"], + "val_right": [1, 3, 7], + } + expected_forward = { + "antananarivo": [1, 5, 10], + "val": ["a", "b", "c"], + "val_right": [1, 6, float("nan")], + } + expected_nearest = { + "antananarivo": [1, 5, 10], + "val": ["a", "b", "c"], + "val_right": [1, 6, 7], + } + compare_dicts(result_backward, expected_backward) + compare_dicts(result_forward, expected_forward) + compare_dicts(result_nearest, expected_nearest) + compare_dicts(result_backward_on, expected_backward) + compare_dicts(result_forward_on, expected_forward) + compare_dicts(result_nearest_on, expected_nearest) + + +def 
test_joinasof_time(constructor: Any, request: Any) -> None: + if "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + if parse_version(pd.__version__) < (2, 1) and ("pandas_pyarrow" in str(constructor)): + request.applymarker(pytest.mark.xfail) + df = nw.from_native( + constructor( + { + "datetime": [ + datetime(2016, 3, 1), + datetime(2018, 8, 1), + datetime(2019, 1, 1), + ], + "population": [82.19, 82.66, 83.12], + } + ) + ).sort("datetime") + df_right = nw.from_native( + constructor( + { + "datetime": [ + datetime(2016, 1, 1), + datetime(2017, 1, 1), + datetime(2018, 1, 1), + datetime(2019, 1, 1), + datetime(2020, 1, 1), + ], + "gdp": [4164, 4411, 4566, 4696, 4827], + } + ) + ).sort("datetime") + result_backward = df.join_asof(df_right, left_on="datetime", right_on="datetime") # type: ignore[arg-type] + result_forward = df.join_asof( + df_right, # type: ignore[arg-type] + left_on="datetime", + right_on="datetime", + strategy="forward", + ) + result_nearest = df.join_asof( + df_right, # type: ignore[arg-type] + left_on="datetime", + right_on="datetime", + strategy="nearest", + ) + result_backward_on = df.join_asof(df_right, on="datetime") # type: ignore[arg-type] + result_forward_on = df.join_asof( + df_right, # type: ignore[arg-type] + on="datetime", + strategy="forward", + ) + result_nearest_on = df.join_asof( + df_right, # type: ignore[arg-type] + on="datetime", + strategy="nearest", + ) + expected_backward = { + "datetime": [datetime(2016, 3, 1), datetime(2018, 8, 1), datetime(2019, 1, 1)], + "population": [82.19, 82.66, 83.12], + "gdp": [4164, 4566, 4696], + } + expected_forward = { + "datetime": [datetime(2016, 3, 1), datetime(2018, 8, 1), datetime(2019, 1, 1)], + "population": [82.19, 82.66, 83.12], + "gdp": [4411, 4696, 4696], + } + expected_nearest = { + "datetime": [datetime(2016, 3, 1), datetime(2018, 8, 1), datetime(2019, 1, 1)], + "population": [82.19, 82.66, 83.12], + "gdp": [4164, 4696, 4696], + } + compare_dicts(result_backward, expected_backward) + compare_dicts(result_forward, expected_forward) + compare_dicts(result_nearest, expected_nearest) + compare_dicts(result_backward_on, expected_backward) + compare_dicts(result_forward_on, expected_forward) + compare_dicts(result_nearest_on, expected_nearest) + + +def test_joinasof_by(constructor: Any, request: Any) -> None: + if "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + if parse_version(pd.__version__) < (2, 1) and ( + ("pandas_pyarrow" in str(constructor)) or ("pandas_nullable" in str(constructor)) + ): + request.applymarker(pytest.mark.xfail) + df = nw.from_native( + constructor( + { + "antananarivo": [1, 5, 7, 10], + "bob": ["D", "D", "C", "A"], + "c": [9, 2, 1, 1], + } + ) + ).sort("antananarivo") + df_right = nw.from_native( + constructor( + {"antananarivo": [1, 4, 5, 8], "bob": ["D", "D", "A", "F"], "d": [1, 3, 4, 1]} + ) + ).sort("antananarivo") + result = df.join_asof(df_right, on="antananarivo", by_left="bob", by_right="bob") # type: ignore[arg-type] + result_by = df.join_asof(df_right, on="antananarivo", by="bob") # type: ignore[arg-type] + expected = { + "antananarivo": [1, 5, 7, 10], + "bob": ["D", "D", "C", "A"], + "c": [9, 2, 1, 1], + "d": [1, 3, float("nan"), 4], + } + compare_dicts(result, expected) + compare_dicts(result_by, expected) + + +@pytest.mark.parametrize("strategy", ["back", "furthest"]) +def test_joinasof_not_implemented( + constructor: Any, strategy: Literal["backward", "forward"] +) -> None: + data = {"antananarivo": [1, 3, 2], 
"bob": [4, 4, 6], "zorro": [7.0, 8, 9]} + df = nw.from_native(constructor(data)) + + with pytest.raises( + NotImplementedError, + match=rf"Only the following strategies are supported: \('backward', 'forward', 'nearest'\); found '{strategy}'.", + ): + df.join_asof( + df, # type: ignore[arg-type] + left_on="antananarivo", + right_on="antananarivo", + strategy=strategy, + ) + + +def test_joinasof_keys_exceptions(constructor: Any) -> None: + data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} + df = nw.from_native(constructor(data)) + + with pytest.raises( + ValueError, + match=r"Either \(`left_on` and `right_on`\) or `on` keys should be specified.", + ): + df.join_asof(df, left_on="antananarivo") # type: ignore[arg-type] + with pytest.raises( + ValueError, + match=r"Either \(`left_on` and `right_on`\) or `on` keys should be specified.", + ): + df.join_asof(df, right_on="antananarivo") # type: ignore[arg-type] + with pytest.raises( + ValueError, + match=r"Either \(`left_on` and `right_on`\) or `on` keys should be specified.", + ): + df.join_asof(df) # type: ignore[arg-type] + with pytest.raises( + ValueError, + match="If `on` is specified, `left_on` and `right_on` should be None.", + ): + df.join_asof( + df, # type: ignore[arg-type] + left_on="antananarivo", + right_on="antananarivo", + on="antananarivo", + ) + with pytest.raises( + ValueError, + match="If `on` is specified, `left_on` and `right_on` should be None.", + ): + df.join_asof(df, left_on="antananarivo", on="antananarivo") # type: ignore[arg-type] + with pytest.raises( + ValueError, + match="If `on` is specified, `left_on` and `right_on` should be None.", + ): + df.join_asof(df, right_on="antananarivo", on="antananarivo") # type: ignore[arg-type] + + +def test_joinasof_by_exceptions(constructor: Any) -> None: + data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} + df = nw.from_native(constructor(data)) + with pytest.raises( + ValueError, + match="If `by` is specified, `by_left` and `by_right` should be None.", + ): + df.join_asof(df, on="antananarivo", by_left="bob", by_right="bob", by="bob") # type: ignore[arg-type] + + with pytest.raises( + ValueError, + match="Can not specify only `by_left` or `by_right`, you need to specify both.", + ): + df.join_asof(df, on="antananarivo", by_left="bob") # type: ignore[arg-type] + + with pytest.raises( + ValueError, + match="Can not specify only `by_left` or `by_right`, you need to specify both.", + ): + df.join_asof(df, on="antananarivo", by_right="bob") # type: ignore[arg-type] + + with pytest.raises( + ValueError, + match="If `by` is specified, `by_left` and `by_right` should be None.", + ): + df.join_asof(df, on="antananarivo", by_left="bob", by="bob") # type: ignore[arg-type] + + with pytest.raises( + ValueError, + match="If `by` is specified, `by_left` and `by_right` should be None.", + ): + df.join_asof(df, on="antananarivo", by_right="bob", by="bob") # type: ignore[arg-type] diff --git a/tests/frame/lit_test.py b/tests/frame/lit_test.py index 328e4d8e0..e5756e035 100644 --- a/tests/frame/lit_test.py +++ b/tests/frame/lit_test.py @@ -17,11 +17,7 @@ ("dtype", "expected_lit"), [(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])], ) -def test_lit( - constructor: Any, dtype: DType | None, expected_lit: list[Any], request: Any -) -> None: - if "dask" in str(constructor) and dtype == nw.String: - request.applymarker(pytest.mark.xfail) +def test_lit(constructor: Any, dtype: DType | None, expected_lit: list[Any]) -> None: 
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_raw = constructor(data) df = nw.from_native(df_raw).lazy() diff --git a/tests/frame/slice_test.py b/tests/frame/slice_test.py index eea94d440..834e88bff 100644 --- a/tests/frame/slice_test.py +++ b/tests/frame/slice_test.py @@ -141,6 +141,15 @@ def test_slice_slice_columns(constructor_eager: Any) -> None: result = df[[0, 1], 1:] expected = {"b": [4, 5], "c": [7, 8], "d": [1, 4]} compare_dicts(result, expected) + result = df[:, ["b", "d"]] + expected = {"b": [4, 5, 6], "d": [1, 4, 2]} + compare_dicts(result, expected) + result = df[:, [0, 2]] + expected = {"a": [1, 2, 3], "c": [7, 8, 9]} + compare_dicts(result, expected) + result = df[["b", "c"]] + expected = {"b": [4, 5, 6], "c": [7, 8, 9]} + compare_dicts(result, expected) def test_slice_invalid(constructor_eager: Any) -> None: diff --git a/tests/frame/tail_test.py b/tests/frame/tail_test.py index e279caba9..b64d9fa6c 100644 --- a/tests/frame/tail_test.py +++ b/tests/frame/tail_test.py @@ -1,7 +1,10 @@ from __future__ import annotations +from contextlib import nullcontext as does_not_raise from typing import Any +import pytest + import narwhals.stable.v1 as nw from tests.utils import compare_dicts @@ -13,14 +16,24 @@ def test_tail(constructor: Any) -> None: df_raw = constructor(data) df = nw.from_native(df_raw).lazy() - result = df.tail(2) - compare_dicts(result, expected) + context = ( + pytest.raises( + NotImplementedError, + match="`LazyFrame.tail` is not supported for Dask backend with multiple partitions.", + ) + if "dask_lazy_p2" in str(constructor) + else does_not_raise() + ) + + with context: + result = df.tail(2) + compare_dicts(result, expected) - result = df.collect().tail(2) # type: ignore[assignment] - compare_dicts(result, expected) + result = df.collect().tail(2) # type: ignore[assignment] + compare_dicts(result, expected) - result = df.collect().tail(-1) # type: ignore[assignment] - compare_dicts(result, expected) + result = df.collect().tail(-1) # type: ignore[assignment] + compare_dicts(result, expected) - result = df.collect().select(nw.col("a").tail(2)) # type: ignore[assignment] - compare_dicts(result, {"a": expected["a"]}) + result = df.collect().select(nw.col("a").tail(2)) # type: ignore[assignment] + compare_dicts(result, {"a": expected["a"]}) diff --git a/tests/frame/test_invalid.py b/tests/frame/test_invalid.py index cf1fff6d1..b8bca586f 100644 --- a/tests/frame/test_invalid.py +++ b/tests/frame/test_invalid.py @@ -24,7 +24,7 @@ def test_validate_laziness() -> None: NotImplementedError, match=("The items to concatenate should either all be eager, or all lazy"), ): - nw.concat([nw.from_native(df, eager_only=True), nw.from_native(df).lazy()]) + nw.concat([nw.from_native(df, eager_only=True), nw.from_native(df).lazy()]) # type: ignore[list-item] @pytest.mark.skipif( diff --git a/tests/hypothesis/test_basic_arithmetic.py b/tests/hypothesis/test_basic_arithmetic.py index 2ab7bad7b..00818271d 100644 --- a/tests/hypothesis/test_basic_arithmetic.py +++ b/tests/hypothesis/test_basic_arithmetic.py @@ -22,7 +22,7 @@ max_size=3, ), ) # type: ignore[misc] -@pytest.mark.slow() +@pytest.mark.slow def test_mean( integer: st.SearchStrategy[list[int]], floats: st.SearchStrategy[float], diff --git a/tests/hypothesis/test_concat.py b/tests/hypothesis/test_concat.py index 1b1248628..9ae54dbc4 100644 --- a/tests/hypothesis/test_concat.py +++ b/tests/hypothesis/test_concat.py @@ -31,7 +31,7 @@ ), how=st.sampled_from(["horizontal", "vertical"]), ) # type: ignore[misc] 
-@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows") def test_concat( # pragma: no cover integers: list[int], diff --git a/tests/hypothesis/test_join.py b/tests/hypothesis/test_join.py index ebdb88757..bc1cd735c 100644 --- a/tests/hypothesis/test_join.py +++ b/tests/hypothesis/test_join.py @@ -42,7 +42,7 @@ ) # type: ignore[misc] @pytest.mark.skipif(pl_version < parse_version("0.20.13"), reason="0.0 == -0.0") @pytest.mark.skipif(pd_version < parse_version("2.0.0"), reason="requires pyarrow") -@pytest.mark.slow() +@pytest.mark.slow def test_join( # pragma: no cover integers: st.SearchStrategy[list[int]], other_integers: st.SearchStrategy[list[int]], @@ -88,7 +88,7 @@ def test_join( # pragma: no cover max_size=3, ), ) # type: ignore[misc] -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.skipif(pd_version < parse_version("2.0.0"), reason="requires pyarrow") def test_cross_join( # pragma: no cover integers: st.SearchStrategy[list[int]], @@ -135,7 +135,7 @@ def test_cross_join( # pragma: no cover st.sampled_from(["a", "b", "d"]), min_size=1, max_size=3, unique=True ), ) -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.filterwarnings("ignore:the default coalesce behavior") def test_left_join( # pragma: no cover a_left_data: list[int], diff --git a/tests/no_imports_test.py b/tests/no_imports_test.py index a89ed0ed8..b30545380 100644 --- a/tests/no_imports_test.py +++ b/tests/no_imports_test.py @@ -13,6 +13,7 @@ def test_polars(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delitem(sys.modules, "numpy") monkeypatch.delitem(sys.modules, "pyarrow") monkeypatch.delitem(sys.modules, "dask", raising=False) + monkeypatch.delitem(sys.modules, "ibis", raising=False) df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) nw.from_native(df, eager_only=True).group_by("a").agg(nw.col("b").mean()).filter( nw.col("a") > 1 @@ -22,12 +23,14 @@ def test_polars(monkeypatch: pytest.MonkeyPatch) -> None: assert "numpy" not in sys.modules assert "pyarrow" not in sys.modules assert "dask" not in sys.modules + assert "ibis" not in sys.modules def test_pandas(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delitem(sys.modules, "polars") monkeypatch.delitem(sys.modules, "pyarrow") monkeypatch.delitem(sys.modules, "dask", raising=False) + monkeypatch.delitem(sys.modules, "ibis", raising=False) df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) nw.from_native(df, eager_only=True).group_by("a").agg(nw.col("b").mean()).filter( nw.col("a") > 1 @@ -37,6 +40,7 @@ def test_pandas(monkeypatch: pytest.MonkeyPatch) -> None: assert "numpy" in sys.modules assert "pyarrow" not in sys.modules assert "dask" not in sys.modules + assert "ibis" not in sys.modules def test_dask(monkeypatch: pytest.MonkeyPatch) -> None: @@ -59,6 +63,7 @@ def test_pyarrow(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delitem(sys.modules, "polars") monkeypatch.delitem(sys.modules, "pandas") monkeypatch.delitem(sys.modules, "dask", raising=False) + monkeypatch.delitem(sys.modules, "ibis", raising=False) df = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}) nw.from_native(df).group_by("a").agg(nw.col("b").mean()).filter(nw.col("a") > 1) assert "polars" not in sys.modules @@ -66,3 +71,4 @@ def test_pyarrow(monkeypatch: pytest.MonkeyPatch) -> None: assert "numpy" in sys.modules assert "pyarrow" in sys.modules assert "dask" not in sys.modules + assert "ibis" not in sys.modules diff --git a/tests/series_only/arrow_c_stream_test.py b/tests/series_only/arrow_c_stream_test.py index 
9964d7408..9d2ebc8d0 100644 --- a/tests/series_only/arrow_c_stream_test.py +++ b/tests/series_only/arrow_c_stream_test.py @@ -10,6 +10,9 @@ @pytest.mark.skipif( parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" ) +@pytest.mark.skipif( + parse_version(pa.__version__) < (16, 0, 0), reason="too old for pycapsule in PyArrow" +) def test_arrow_c_stream_test() -> None: s = nw.from_native(pl.Series([1, 2, 3]), series_only=True) result = pa.chunked_array(s) @@ -20,6 +23,9 @@ def test_arrow_c_stream_test() -> None: @pytest.mark.skipif( parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" ) +@pytest.mark.skipif( + parse_version(pa.__version__) < (16, 0, 0), reason="too old for pycapsule in PyArrow" +) def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None: # "poison" the dunder method to make sure it actually got called above monkeypatch.setattr("narwhals.series.Series.__arrow_c_stream__", lambda *_: 1 / 0) @@ -31,6 +37,9 @@ def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None: @pytest.mark.skipif( parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" ) +@pytest.mark.skipif( + parse_version(pa.__version__) < (16, 0, 0), reason="too old for pycapsule in PyArrow" +) def test_arrow_c_stream_test_fallback(monkeypatch: pytest.MonkeyPatch) -> None: # Check that fallback to PyArrow works monkeypatch.delattr("polars.Series.__arrow_c_stream__") diff --git a/tests/series_only/scatter_test.py b/tests/series_only/scatter_test.py new file mode 100644 index 000000000..2edab2b8c --- /dev/null +++ b/tests/series_only/scatter_test.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import narwhals as nw +from tests.utils import compare_dicts + + +def test_scatter(constructor_eager: Any, request: pytest.FixtureRequest) -> None: + if "modin" in str(constructor_eager): + # https://github.com/modin-project/modin/issues/7392 + request.applymarker(pytest.mark.xfail) + df = nw.from_native( + constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True + ) + result = df.with_columns( + df["a"].scatter([0, 1], [999, 888]), + df["b"].scatter([0, 2, 1], df["b"]), + ) + expected = { + "a": [999, 888, 3], + "b": [142, 132, 124], + } + compare_dicts(result, expected) diff --git a/tests/series_only/slice_test.py b/tests/series_only/slice_test.py new file mode 100644 index 000000000..f9d2b4e2f --- /dev/null +++ b/tests/series_only/slice_test.py @@ -0,0 +1,15 @@ +from typing import Any + +import narwhals.stable.v1 as nw +from tests.utils import compare_dicts + + +def test_slice(constructor_eager: Any) -> None: + data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [1, 4, 2]} + df = nw.from_native(constructor_eager(data), eager_only=True) + result = {"a": df["a"][[0, 1]]} + expected = {"a": [1, 2]} + compare_dicts(result, expected) + result = {"a": df["a"][1:]} + expected = {"a": [2, 3]} + compare_dicts(result, expected) diff --git a/tests/series_only/to_dummy_test.py b/tests/series_only/to_dummy_test.py index 2cf7f59c7..404ac6321 100644 --- a/tests/series_only/to_dummy_test.py +++ b/tests/series_only/to_dummy_test.py @@ -18,7 +18,9 @@ def test_to_dummies(constructor_eager: Any, sep: str) -> None: @pytest.mark.parametrize("sep", ["_", "-"]) -def test_to_dummies_drop_first(constructor_eager: Any, sep: str) -> None: +def test_to_dummies_drop_first(request: Any, constructor_eager: Any, sep: str) -> None: + if "cudf" in 
str(constructor_eager): + request.applymarker(pytest.mark.xfail) s = nw.from_native(constructor_eager({"a": data}), eager_only=True)["a"].alias("a") result = s.to_dummies(drop_first=True, separator=sep) expected = {f"a{sep}2": [0, 1, 0], f"a{sep}3": [0, 0, 1]} diff --git a/tests/test_group_by.py b/tests/test_group_by.py index 2bb8d435b..6f12d06b1 100644 --- a/tests/test_group_by.py +++ b/tests/test_group_by.py @@ -102,6 +102,57 @@ def test_group_by_len(constructor: Any) -> None: compare_dicts(result, expected) +def test_group_by_n_unique(constructor: Any) -> None: + result = ( + nw.from_native(constructor(data)) + .group_by("a") + .agg(nw.col("b").n_unique()) + .sort("a") + ) + expected = {"a": [1, 3], "b": [1, 1]} + compare_dicts(result, expected) + + +def test_group_by_std(constructor: Any) -> None: + data = {"a": [1, 1, 2, 2], "b": [5, 4, 3, 2]} + result = ( + nw.from_native(constructor(data)).group_by("a").agg(nw.col("b").std()).sort("a") + ) + expected = {"a": [1, 2], "b": [0.707107] * 2} + compare_dicts(result, expected) + + +def test_group_by_n_unique_w_missing(constructor: Any) -> None: + data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]} + result = ( + nw.from_native(constructor(data)) + .group_by("a") + .agg( + nw.col("b").n_unique(), + c_n_unique=nw.col("c").n_unique(), + c_n_min=nw.col("b").min(), + d_n_unique=nw.col("d").n_unique(), + ) + .sort("a") + ) + expected = { + "a": [1, 2], + "b": [2, 1], + "c_n_unique": [1, 1], + "c_n_min": [4, 5], + "d_n_unique": [1, 1], + } + compare_dicts(result, expected) + + +def test_group_by_same_name_twice() -> None: + import pandas as pd + + df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) + with pytest.raises(ValueError, match="two aggregations with the same"): + nw.from_native(df).group_by("a").agg(nw.col("b").sum(), nw.col("b").n_unique()) + + def test_group_by_empty_result_pandas() -> None: df_any = pd.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]}) df = nw.from_native(df_any, eager_only=True) @@ -195,3 +246,10 @@ def test_key_with_nulls(constructor: Any, request: Any) -> None: ) expected = {"b": [4.0, 5, float("nan")], "len": [1, 1, 1], "a": [1, 2, 3]} compare_dicts(result, expected) + + +def test_no_agg(constructor: Any) -> None: + result = nw.from_native(constructor(data)).group_by(["a", "b"]).agg().sort("a", "b") + + expected = {"a": [1, 3], "b": [4, 6]} + compare_dicts(result, expected) diff --git a/tests/test_utils.py b/tests/test_utils.py index 68dc90ed7..f51c28eab 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,15 +15,15 @@ def test_maybe_align_index_pandas() -> None: result = nw.maybe_align_index(df, s) expected = pd.DataFrame({"a": [2, 1, 3]}, index=[2, 1, 0]) assert_frame_equal(nw.to_native(result), expected) - result = nw.maybe_align_index(s, df) - expected = pd.Series([2, 1, 3], index=[1, 2, 0]) - assert_series_equal(nw.to_native(result), expected) - result = nw.maybe_align_index(s, s.sort(descending=True)) - expected = pd.Series([3, 2, 1], index=[0, 1, 2]) - assert_series_equal(nw.to_native(result), expected) result = nw.maybe_align_index(df, df.sort("a", descending=True)) expected = pd.DataFrame({"a": [3, 2, 1]}, index=[0, 2, 1]) assert_frame_equal(nw.to_native(result), expected) + result_s = nw.maybe_align_index(s, df) + expected_s = pd.Series([2, 1, 3], index=[1, 2, 0]) + assert_series_equal(nw.to_native(result_s), expected_s) + result_s = nw.maybe_align_index(s, s.sort(descending=True)) + expected_s = pd.Series([3, 2, 1], index=[0, 1, 2]) + 
assert_series_equal(nw.to_native(result_s), expected_s) def test_with_columns_sort() -> None: diff --git a/tests/translate/from_native_test.py b/tests/translate/from_native_test.py index af49c0226..8ac33b620 100644 --- a/tests/translate/from_native_test.py +++ b/tests/translate/from_native_test.py @@ -199,3 +199,14 @@ def test_eager_only_lazy_dask(eager_only: Any, context: Any) -> None: with context: res = nw.from_native(dframe, eager_only=eager_only) assert isinstance(res, nw.LazyFrame) + + +def test_from_native_strict_false_typing() -> None: + df = pl.DataFrame() + nw.from_native(df, strict=False) + nw.from_native(df, strict=False, eager_only=True) + nw.from_native(df, strict=False, eager_or_interchange_only=True) + + unstable_nw.from_native(df, strict=False) + unstable_nw.from_native(df, strict=False, eager_only=True) + unstable_nw.from_native(df, strict=False, eager_or_interchange_only=True) diff --git a/tpch/README.md b/tpch/README.md new file mode 100644 index 000000000..3ae09b723 --- /dev/null +++ b/tpch/README.md @@ -0,0 +1,17 @@ +# Narwhals TPC-H queries + +Utilities for running the TPC-H queries via Narwhals. + +Before getting started, make sure you've followed the instructions in +`CONTRIBUTING.MD`. + +## Generate data + +Run `python generate_data.py` from this folder. + +## Run queries + +To run Q1, you can run `python -m execute.q1` from this folder. + +Please add query definitions in `queries`, and scripts to execute them +in `execute` (see `queries/q1.py` and `execute/q1.py` for examples). diff --git a/tpch/execute/__init__.py b/tpch/execute/__init__.py new file mode 100644 index 000000000..e0c448649 --- /dev/null +++ b/tpch/execute/__init__.py @@ -0,0 +1,30 @@ +from pathlib import Path + +import dask.dataframe as dd +import pandas as pd +import polars as pl +import pyarrow.parquet as pq + +pd.options.mode.copy_on_write = True +pd.options.future.infer_string = True + +lineitem = Path("data") / "lineitem.parquet" +region = Path("data") / "region.parquet" +nation = Path("data") / "nation.parquet" +supplier = Path("data") / "supplier.parquet" +part = Path("data") / "part.parquet" +partsupp = Path("data") / "partsupp.parquet" +orders = Path("data") / "orders.parquet" +customer = Path("data") / "customer.parquet" +line_item = Path("data") / "lineitem.parquet" + +IO_FUNCS = { + "pandas": lambda x: pd.read_parquet(x, engine="pyarrow"), + "pandas[pyarrow]": lambda x: pd.read_parquet( + x, engine="pyarrow", dtype_backend="pyarrow" + ), + "polars[eager]": lambda x: pl.read_parquet(x), + "polars[lazy]": lambda x: pl.scan_parquet(x), + "pyarrow": lambda x: pq.read_table(x), + "dask": lambda x: dd.read_parquet(x, engine="pyarrow", dtype_backend="pyarrow"), +} diff --git a/tpch/execute/q1.py b/tpch/execute/q1.py new file mode 100644 index 000000000..9889c3af0 --- /dev/null +++ b/tpch/execute/q1.py @@ -0,0 +1,9 @@ +from queries import q1 + +from . import IO_FUNCS +from . import lineitem + +print(q1.query(IO_FUNCS["pandas[pyarrow]"](lineitem))) +print(q1.query(IO_FUNCS["polars[lazy]"](lineitem)).collect()) +print(q1.query(IO_FUNCS["pyarrow"](lineitem))) +print(q1.query(IO_FUNCS["dask"](lineitem)).compute()) diff --git a/tpch/execute/q10.py b/tpch/execute/q10.py new file mode 100644 index 000000000..e1d56d36b --- /dev/null +++ b/tpch/execute/q10.py @@ -0,0 +1,19 @@ +from queries import q10 + +from . import IO_FUNCS +from . import customer +from . import lineitem +from . import nation +from . 
import orders + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q10.query(fn(customer), fn(nation), fn(lineitem), fn(orders))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q10.query(fn(customer), fn(nation), fn(lineitem), fn(orders)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q10.query(fn(customer), fn(nation), fn(lineitem), fn(orders))) diff --git a/tpch/execute/q11.py b/tpch/execute/q11.py new file mode 100644 index 000000000..a6b830f30 --- /dev/null +++ b/tpch/execute/q11.py @@ -0,0 +1,18 @@ +from queries import q11 + +from . import IO_FUNCS +from . import nation +from . import partsupp +from . import supplier + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q11.query(fn(nation), fn(partsupp), fn(supplier))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q11.query(fn(nation), fn(partsupp), fn(supplier)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q11.query(fn(nation), fn(partsupp), fn(supplier))) diff --git a/tpch/execute/q12.py b/tpch/execute/q12.py new file mode 100644 index 000000000..0cdc0378b --- /dev/null +++ b/tpch/execute/q12.py @@ -0,0 +1,17 @@ +from queries import q12 + +from . import IO_FUNCS +from . import line_item +from . import orders + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q12.query(fn(line_item), fn(orders))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q12.query(fn(line_item), fn(orders)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q12.query(fn(line_item), fn(orders))) diff --git a/tpch/execute/q13.py b/tpch/execute/q13.py new file mode 100644 index 000000000..b5e6c8bbe --- /dev/null +++ b/tpch/execute/q13.py @@ -0,0 +1,17 @@ +from queries import q13 + +from . import IO_FUNCS +from . import customer +from . import orders + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q13.query(fn(customer), fn(orders))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q13.query(fn(customer), fn(orders)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q13.query(fn(customer), fn(orders))) diff --git a/tpch/execute/q14.py b/tpch/execute/q14.py new file mode 100644 index 000000000..1a89dbbbe --- /dev/null +++ b/tpch/execute/q14.py @@ -0,0 +1,17 @@ +from queries import q14 + +from . import IO_FUNCS +from . import line_item +from . import part + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q14.query(fn(line_item), fn(part))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q14.query(fn(line_item), fn(part)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q14.query(fn(line_item), fn(part))) diff --git a/tpch/execute/q15.py b/tpch/execute/q15.py new file mode 100644 index 000000000..ac858841d --- /dev/null +++ b/tpch/execute/q15.py @@ -0,0 +1,17 @@ +from queries import q15 + +from . import IO_FUNCS +from . import lineitem +from . import supplier + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q15.query(fn(lineitem), fn(supplier))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q15.query(fn(lineitem), fn(supplier)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q15.query(fn(lineitem), fn(supplier))) diff --git a/tpch/execute/q16.py b/tpch/execute/q16.py new file mode 100644 index 000000000..7fa6c72b0 --- /dev/null +++ b/tpch/execute/q16.py @@ -0,0 +1,18 @@ +from queries import q16 + +from . import IO_FUNCS +from . import part +from . import partsupp +from . 
import supplier + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q16.query(fn(part), fn(partsupp), fn(supplier))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q16.query(fn(part), fn(partsupp), fn(supplier)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q16.query(fn(part), fn(partsupp), fn(supplier))) diff --git a/tpch/execute/q17.py b/tpch/execute/q17.py new file mode 100644 index 000000000..8eefb92dc --- /dev/null +++ b/tpch/execute/q17.py @@ -0,0 +1,17 @@ +from queries import q17 + +from . import IO_FUNCS +from . import lineitem +from . import part + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q17.query(fn(lineitem), fn(part))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q17.query(fn(lineitem), fn(part)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q17.query(fn(lineitem), fn(part))) diff --git a/tpch/execute/q18.py b/tpch/execute/q18.py new file mode 100644 index 000000000..fdd50c095 --- /dev/null +++ b/tpch/execute/q18.py @@ -0,0 +1,18 @@ +from queries import q18 + +from . import IO_FUNCS +from . import customer +from . import lineitem +from . import orders + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q18.query(fn(customer), fn(lineitem), fn(orders))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q18.query(fn(customer), fn(lineitem), fn(orders)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q18.query(fn(customer), fn(lineitem), fn(orders))) diff --git a/tpch/execute/q19.py b/tpch/execute/q19.py new file mode 100644 index 000000000..e1dff3eb5 --- /dev/null +++ b/tpch/execute/q19.py @@ -0,0 +1,14 @@ +from queries import q19 + +from . import IO_FUNCS +from . import lineitem +from . import part + +fn = IO_FUNCS["pandas[pyarrow]"] +print(q19.query(fn(lineitem), fn(part))) + +fn = IO_FUNCS["polars[lazy]"] +print(q19.query(fn(lineitem), fn(part)).collect()) + +fn = IO_FUNCS["pyarrow"] +print(q19.query(fn(lineitem), fn(part))) diff --git a/tpch/execute/q2.py b/tpch/execute/q2.py new file mode 100644 index 000000000..cd82a9047 --- /dev/null +++ b/tpch/execute/q2.py @@ -0,0 +1,53 @@ +from queries import q2 + +from . import IO_FUNCS +from . import nation +from . import part +from . import partsupp +from . import region +from . import supplier + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print( + q2.query( + fn(region), + fn(nation), + fn(supplier), + fn(part), + fn(partsupp), + ) +) +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print( + q2.query( + fn(region), + fn(nation), + fn(supplier), + fn(part), + fn(partsupp), + ).collect() +) +tool = "pyarrow" +fn = IO_FUNCS[tool] +print( + q2.query( + fn(region), + fn(nation), + fn(supplier), + fn(part), + fn(partsupp), + ) +) +tool = "dask" +fn = IO_FUNCS[tool] +print( + q2.query( + fn(region), + fn(nation), + fn(supplier), + fn(part), + fn(partsupp), + ).compute() +) diff --git a/tpch/execute/q20.py b/tpch/execute/q20.py new file mode 100644 index 000000000..d15f8c85f --- /dev/null +++ b/tpch/execute/q20.py @@ -0,0 +1,17 @@ +from queries import q20 + +from . import IO_FUNCS +from . import lineitem +from . import nation +from . import part +from . import partsupp +from . 
import supplier + +fn = IO_FUNCS["pandas[pyarrow]"] +print(q20.query(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(supplier))) + +fn = IO_FUNCS["polars[lazy]"] +print(q20.query(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(supplier)).collect()) + +fn = IO_FUNCS["pyarrow"] +print(q20.query(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(supplier))) diff --git a/tpch/execute/q21.py b/tpch/execute/q21.py new file mode 100644 index 000000000..9940e6232 --- /dev/null +++ b/tpch/execute/q21.py @@ -0,0 +1,16 @@ +from queries import q21 + +from . import IO_FUNCS +from . import lineitem +from . import nation +from . import orders +from . import supplier + +fn = IO_FUNCS["pandas[pyarrow]"] +print(q21.query(fn(lineitem), fn(nation), fn(orders), fn(supplier))) + +fn = IO_FUNCS["polars[lazy]"] +print(q21.query(fn(lineitem), fn(nation), fn(orders), fn(supplier)).collect()) + +fn = IO_FUNCS["pyarrow"] +print(q21.query(fn(lineitem), fn(nation), fn(orders), fn(supplier))) diff --git a/tpch/execute/q22.py b/tpch/execute/q22.py new file mode 100644 index 000000000..3b3fe523f --- /dev/null +++ b/tpch/execute/q22.py @@ -0,0 +1,17 @@ +from queries import q22 + +from . import IO_FUNCS +from . import customer +from . import orders + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q22.query(fn(customer), fn(orders))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q22.query(fn(customer), fn(orders)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q22.query(fn(customer), fn(orders))) diff --git a/tpch/execute/q3.py b/tpch/execute/q3.py new file mode 100644 index 000000000..f836fae27 --- /dev/null +++ b/tpch/execute/q3.py @@ -0,0 +1,18 @@ +from queries import q3 + +from . import IO_FUNCS +from . import customer +from . import lineitem +from . import orders + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q3.query(fn(customer), fn(lineitem), fn(orders))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q3.query(fn(customer), fn(lineitem), fn(orders)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q3.query(fn(customer), fn(lineitem), fn(orders))) diff --git a/tpch/execute/q4.py b/tpch/execute/q4.py new file mode 100644 index 000000000..ca60f38ee --- /dev/null +++ b/tpch/execute/q4.py @@ -0,0 +1,17 @@ +from queries import q4 + +from . import IO_FUNCS +from . import line_item +from . import orders + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q4.query(fn(line_item), fn(orders))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q4.query(fn(line_item), fn(orders)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q4.query(fn(line_item), fn(orders))) diff --git a/tpch/execute/q5.py b/tpch/execute/q5.py new file mode 100644 index 000000000..c343fea5d --- /dev/null +++ b/tpch/execute/q5.py @@ -0,0 +1,33 @@ +from queries import q5 + +from . import IO_FUNCS +from . import customer +from . import line_item +from . import nation +from . import orders +from . import region +from . 
import supplier + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print( + q5.query( + fn(region), fn(nation), fn(customer), fn(line_item), fn(orders), fn(supplier) + ) +) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print( + q5.query( + fn(region), fn(nation), fn(customer), fn(line_item), fn(orders), fn(supplier) + ).collect() +) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print( + q5.query( + fn(region), fn(nation), fn(customer), fn(line_item), fn(orders), fn(supplier) + ) +) diff --git a/tpch/execute/q6.py b/tpch/execute/q6.py new file mode 100644 index 000000000..eebf3f864 --- /dev/null +++ b/tpch/execute/q6.py @@ -0,0 +1,16 @@ +from queries import q6 + +from . import IO_FUNCS +from . import lineitem + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q6.query(fn(lineitem))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print(q6.query(fn(lineitem)).collect()) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q6.query(fn(lineitem))) diff --git a/tpch/execute/q7.py b/tpch/execute/q7.py new file mode 100644 index 000000000..c59f82ce7 --- /dev/null +++ b/tpch/execute/q7.py @@ -0,0 +1,22 @@ +from queries import q7 + +from . import IO_FUNCS +from . import customer +from . import lineitem +from . import nation +from . import orders +from . import supplier + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print(q7.query(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print( + q7.query(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier)).collect() +) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print(q7.query(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))) diff --git a/tpch/execute/q8.py b/tpch/execute/q8.py new file mode 100644 index 000000000..902a34e70 --- /dev/null +++ b/tpch/execute/q8.py @@ -0,0 +1,53 @@ +from queries import q8 + +from . import IO_FUNCS +from . import customer +from . import lineitem +from . import nation +from . import orders +from . import part +from . import region +from . import supplier + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print( + q8.query( + fn(part), + fn(supplier), + fn(lineitem), + fn(orders), + fn(customer), + fn(nation), + fn(region), + ) +) + + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print( + q8.query( + fn(part), + fn(supplier), + fn(lineitem), + fn(orders), + fn(customer), + fn(nation), + fn(region), + ).collect() +) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print( + q8.query( + fn(part), + fn(supplier), + fn(lineitem), + fn(orders), + fn(customer), + fn(nation), + fn(region), + ) +) diff --git a/tpch/execute/q9.py b/tpch/execute/q9.py new file mode 100644 index 000000000..44d4154aa --- /dev/null +++ b/tpch/execute/q9.py @@ -0,0 +1,29 @@ +from queries import q9 + +from . import IO_FUNCS +from . import lineitem +from . import nation +from . import orders +from . import part +from . import partsupp +from . 
import supplier + +tool = "pandas[pyarrow]" +fn = IO_FUNCS[tool] +print( + q9.query(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(orders), fn(supplier)) +) + +tool = "polars[lazy]" +fn = IO_FUNCS[tool] +print( + q9.query( + fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(orders), fn(supplier) + ).collect() +) + +tool = "pyarrow" +fn = IO_FUNCS[tool] +print( + q9.query(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(orders), fn(supplier)) +) diff --git a/tpch/generate_data.py b/tpch/generate_data.py new file mode 100644 index 000000000..4d5695dcf --- /dev/null +++ b/tpch/generate_data.py @@ -0,0 +1,36 @@ +from pathlib import Path # noqa: INP001 + +import duckdb +import pyarrow as pa +import pyarrow.parquet as pq +import tqdm + +if not Path("data").exists(): + Path("data").mkdir() + +con = duckdb.connect(database=":memory:") +con.execute("INSTALL tpch; LOAD tpch") +con.execute("CALL dbgen(sf=1)") +tables = [ + "lineitem", + "customer", + "nation", + "orders", + "part", + "partsupp", + "region", + "supplier", +] +for t in tqdm.tqdm(tables): + res = con.query("SELECT * FROM " + t) # noqa: S608 + res_arrow = res.to_arrow_table() + new_schema = [] + for field in res_arrow.schema: + if isinstance(field.type, type(pa.decimal128(1))): + new_schema.append(pa.field(field.name, pa.float64())) + elif field.type == pa.date32(): + new_schema.append(pa.field(field.name, pa.timestamp("ns"))) + else: + new_schema.append(field) + res_arrow = res_arrow.cast(pa.schema(new_schema)) + pq.write_table(res_arrow, Path("data") / f"{t}.parquet") diff --git a/tpch/notebooks/gpu/execute.ipynb b/tpch/notebooks/gpu/execute.ipynb deleted file mode 100755 index a117c9187..000000000 --- a/tpch/notebooks/gpu/execute.ipynb +++ /dev/null @@ -1,616 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0", - "metadata": { - "papermill": { - "duration": 13.50746, - "end_time": "2024-03-23T16:07:02.428026", - "exception": false, - "start_time": "2024-03-23T16:06:48.920566", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "!pip install -U narwhals>=0.7.2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": { - "papermill": { - "duration": 276.933638, - "end_time": "2024-03-23T16:11:39.366857", - "exception": false, - "start_time": "2024-03-23T16:07:02.433219", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# Remove all conda packages\n", - "!find /opt/conda \\( -name \"cudf*\" -o -name \"libcudf*\" -o -name \"cuml*\" -o -name \"libcuml*\" \\\n", - " -o -name \"cugraph*\" -o -name \"libcugraph*\" -o -name \"raft*\" -o -name \"libraft*\" \\\n", - " -o -name \"pylibraft*\" -o -name \"libkvikio*\" -o -name \"*dask*\" -o -name \"rmm*\"\\\n", - " -o -name \"librmm*\" \\) -exec rm -rf {} \\; 2>/dev/null\n", - "\n", - "# pip uninstall, just incase there are packages lying around\n", - "!pip uninstall cudf cuml dask-cudf cuml cugraph cupy cupy-cuda12x --y\n", - "\n", - "\n", - "!pip install \\\n", - " --extra-index-url=https://pypi.nvidia.com \\\n", - " cudf-cu12==24.2.* \\\n", - " dask-cudf-cu12==24.2.* \\\n", - " cuml-cu12==24.2.* \\\n", - " cugraph-cu12==24.2.*\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2", - "metadata": { - "papermill": { - "duration": 4.147764, - "end_time": "2024-03-23T16:11:43.699466", - "exception": false, - "start_time": "2024-03-23T16:11:39.551702", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": 
[ - "import cudf" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": { - "papermill": { - "duration": 0.148115, - "end_time": "2024-03-23T16:11:44.003370", - "exception": false, - "start_time": "2024-03-23T16:11:43.855255", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": { - "papermill": { - "duration": 0.186887, - "end_time": "2024-03-23T16:11:44.341982", - "exception": false, - "start_time": "2024-03-23T16:11:44.155095", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from typing import Any\n", - "from datetime import datetime\n", - "import narwhals as nw\n", - "\n", - "def q1(df_raw: Any) -> Any:\n", - " var_1 = datetime(1998, 9, 2)\n", - " df = nw.from_native(df_raw)\n", - " result = (\n", - " df.filter(nw.col(\"l_shipdate\") <= var_1)\n", - " .with_columns(\n", - " disc_price=nw.col(\"l_extendedprice\") * (1 - nw.col(\"l_discount\")),\n", - " charge=(\n", - " nw.col(\"l_extendedprice\")\n", - " * (1.0 - nw.col(\"l_discount\"))\n", - " * (1.0 + nw.col(\"l_tax\"))\n", - " ),\n", - " )\n", - " .group_by([\"l_returnflag\", \"l_linestatus\"])\n", - " .agg(\n", - " [\n", - " nw.sum(\"l_quantity\").alias(\"sum_qty\"),\n", - " nw.sum(\"l_extendedprice\").alias(\"sum_base_price\"),\n", - " nw.sum(\"disc_price\").alias(\"sum_disc_price\"),\n", - " nw.col(\"charge\").sum().alias(\"sum_charge\"),\n", - " nw.mean(\"l_quantity\").alias(\"avg_qty\"),\n", - " nw.mean(\"l_extendedprice\").alias(\"avg_price\"),\n", - " nw.mean(\"l_discount\").alias(\"avg_disc\"),\n", - " nw.len().alias(\"count_order\"),\n", - " ],\n", - " )\n", - " .sort([\"l_returnflag\", \"l_linestatus\"])\n", - " )\n", - " return nw.to_native(result)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5", - "metadata": { - "papermill": { - "duration": 0.158162, - "end_time": "2024-03-23T16:11:44.671154", - "exception": false, - "start_time": "2024-03-23T16:11:44.512992", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "\n", - "\n", - "from typing import Any\n", - "from datetime import datetime\n", - "import narwhals as nw\n", - "\n", - "def q2(\n", - " region_ds_raw: Any,\n", - " nation_ds_raw: Any,\n", - " supplier_ds_raw: Any,\n", - " part_ds_raw: Any,\n", - " part_supp_ds_raw: Any,\n", - ") -> Any:\n", - " var_1 = 15\n", - " var_2 = \"BRASS\"\n", - " var_3 = \"EUROPE\"\n", - "\n", - " region_ds = nw.from_native(region_ds_raw)\n", - " nation_ds = nw.from_native(nation_ds_raw)\n", - " supplier_ds = nw.from_native(supplier_ds_raw)\n", - " part_ds = nw.from_native(part_ds_raw)\n", - " part_supp_ds = nw.from_native(part_supp_ds_raw)\n", - "\n", - " result_q2 = (\n", - " part_ds.join(part_supp_ds, left_on=\"p_partkey\", right_on=\"ps_partkey\")\n", - " .join(supplier_ds, left_on=\"ps_suppkey\", right_on=\"s_suppkey\")\n", - " .join(nation_ds, left_on=\"s_nationkey\", right_on=\"n_nationkey\")\n", - " .join(region_ds, left_on=\"n_regionkey\", right_on=\"r_regionkey\")\n", - " .filter(nw.col(\"p_size\") == var_1)\n", - " .filter(nw.col(\"p_type\").str.ends_with(var_2))\n", - " .filter(nw.col(\"r_name\") == var_3)\n", - " )\n", - "\n", - " final_cols = [\n", - " \"s_acctbal\",\n", - " \"s_name\",\n", - " \"n_name\",\n", - " \"p_partkey\",\n", - " \"p_mfgr\",\n", - " \"s_address\",\n", - " \"s_phone\",\n", - " \"s_comment\",\n", - " ]\n", - "\n", - " q_final = (\n", - " 
result_q2.group_by(\"p_partkey\")\n", - " .agg(nw.min(\"ps_supplycost\").alias(\"ps_supplycost\"))\n", - " .join(\n", - " result_q2,\n", - " left_on=[\"p_partkey\", \"ps_supplycost\"],\n", - " right_on=[\"p_partkey\", \"ps_supplycost\"],\n", - " )\n", - " .select(final_cols)\n", - " .sort(\n", - " by=[\"s_acctbal\", \"n_name\", \"s_name\", \"p_partkey\"],\n", - " descending=[True, False, False, False],\n", - " )\n", - " .head(100)\n", - " )\n", - "\n", - " return nw.to_native(q_final)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": { - "papermill": { - "duration": 0.153978, - "end_time": "2024-03-23T16:11:44.967693", - "exception": false, - "start_time": "2024-03-23T16:11:44.813715", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from typing import Any\n", - "from datetime import datetime\n", - "import narwhals as nw\n", - "\n", - "def q3(\n", - " customer_ds_raw: Any,\n", - " line_item_ds_raw: Any,\n", - " orders_ds_raw: Any,\n", - ") -> Any:\n", - " var_1 = var_2 = datetime(1995, 3, 15)\n", - " var_3 = \"BUILDING\"\n", - "\n", - " customer_ds = nw.from_native(customer_ds_raw)\n", - " line_item_ds = nw.from_native(line_item_ds_raw)\n", - " orders_ds = nw.from_native(orders_ds_raw)\n", - "\n", - " q_final = (\n", - " customer_ds.filter(nw.col(\"c_mktsegment\") == var_3)\n", - " .join(orders_ds, left_on=\"c_custkey\", right_on=\"o_custkey\")\n", - " .join(line_item_ds, left_on=\"o_orderkey\", right_on=\"l_orderkey\")\n", - " .filter(nw.col(\"o_orderdate\") < var_2)\n", - " .filter(nw.col(\"l_shipdate\") > var_1)\n", - " .with_columns(\n", - " (nw.col(\"l_extendedprice\") * (1 - nw.col(\"l_discount\"))).alias(\"revenue\")\n", - " )\n", - " .group_by([\"o_orderkey\", \"o_orderdate\", \"o_shippriority\"])\n", - " .agg([nw.sum(\"revenue\")])\n", - " .select(\n", - " [\n", - " nw.col(\"o_orderkey\").alias(\"l_orderkey\"),\n", - " \"revenue\",\n", - " \"o_orderdate\",\n", - " \"o_shippriority\",\n", - " ]\n", - " )\n", - " .sort(by=[\"revenue\", \"o_orderdate\"], descending=[True, False])\n", - " .head(10)\n", - " )\n", - "\n", - " return nw.to_native(q_final)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": { - "papermill": { - "duration": 0.153799, - "end_time": "2024-03-23T16:11:45.265846", - "exception": false, - "start_time": "2024-03-23T16:11:45.112047", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "\n", - "\n", - "from typing import Any\n", - "from datetime import datetime\n", - "import narwhals as nw\n", - "\n", - "def q4(\n", - " lineitem_ds_raw: Any,\n", - " orders_ds_raw: Any,\n", - ") -> Any:\n", - " var_1 = datetime(1993, 7, 1)\n", - " var_2 = datetime(1993, 10, 1)\n", - "\n", - " line_item_ds = nw.from_native(lineitem_ds_raw)\n", - " orders_ds = nw.from_native(orders_ds_raw)\n", - "\n", - " result = (\n", - " line_item_ds.join(orders_ds, left_on=\"l_orderkey\", right_on=\"o_orderkey\")\n", - " .filter(nw.col(\"o_orderdate\").is_between(var_1, var_2, closed=\"left\"))\n", - " .filter(nw.col(\"l_commitdate\") < nw.col(\"l_receiptdate\"))\n", - " .unique(subset=[\"o_orderpriority\", \"l_orderkey\"])\n", - " .group_by(\"o_orderpriority\")\n", - " .agg(nw.len().alias(\"order_count\"))\n", - " .sort(by=\"o_orderpriority\")\n", - " .with_columns(nw.col(\"order_count\").cast(nw.Int64))\n", - " )\n", - "\n", - " return nw.to_native(result)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8", - "metadata": { - 
"papermill": { - "duration": 0.159087, - "end_time": "2024-03-23T16:11:45.567586", - "exception": false, - "start_time": "2024-03-23T16:11:45.408499", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from typing import Any\n", - "from datetime import datetime\n", - "import narwhals as nw\n", - "\n", - "def q5(\n", - " region_ds_raw: Any,\n", - " nation_ds_raw: Any,\n", - " customer_ds_raw: Any,\n", - " lineitem_ds_raw: Any,\n", - " orders_ds_raw: Any,\n", - " supplier_ds_raw: Any,\n", - ") -> Any:\n", - " var_1 = \"ASIA\"\n", - " var_2 = datetime(1994, 1, 1)\n", - " var_3 = datetime(1995, 1, 1)\n", - "\n", - " region_ds = nw.from_native(region_ds_raw)\n", - " nation_ds = nw.from_native(nation_ds_raw)\n", - " customer_ds = nw.from_native(customer_ds_raw)\n", - " line_item_ds = nw.from_native(lineitem_ds_raw)\n", - " orders_ds = nw.from_native(orders_ds_raw)\n", - " supplier_ds = nw.from_native(supplier_ds_raw)\n", - "\n", - " result = (\n", - " region_ds.join(nation_ds, left_on=\"r_regionkey\", right_on=\"n_regionkey\")\n", - " .join(customer_ds, left_on=\"n_nationkey\", right_on=\"c_nationkey\")\n", - " .join(orders_ds, left_on=\"c_custkey\", right_on=\"o_custkey\")\n", - " .join(line_item_ds, left_on=\"o_orderkey\", right_on=\"l_orderkey\")\n", - " .join(\n", - " supplier_ds,\n", - " left_on=[\"l_suppkey\", \"n_nationkey\"],\n", - " right_on=[\"s_suppkey\", \"s_nationkey\"],\n", - " )\n", - " .filter(\n", - " nw.col(\"r_name\") == var_1,\n", - " nw.col(\"o_orderdate\").is_between(var_2, var_3, closed=\"left\")\n", - " )\n", - " .with_columns(\n", - " (nw.col(\"l_extendedprice\") * (1 - nw.col(\"l_discount\"))).alias(\"revenue\")\n", - " )\n", - " .group_by(\"n_name\")\n", - " .agg([nw.sum(\"revenue\")])\n", - " .sort(by=\"revenue\", descending=True)\n", - " )\n", - "\n", - " return nw.to_native(result)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": { - "papermill": { - "duration": 0.154065, - "end_time": "2024-03-23T16:11:45.865878", - "exception": false, - "start_time": "2024-03-23T16:11:45.711813", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "region = dir_ + 'region.parquet'\n", - "nation = dir_ + 'nation.parquet'\n", - "customer = dir_ + 'customer.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "part = dir_ + 'part.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10", - "metadata": { - "papermill": { - "duration": 0.149554, - "end_time": "2024-03-23T16:11:46.157770", - "exception": false, - "start_time": "2024-03-23T16:11:46.008216", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "results = {}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "11", - "metadata": { - "papermill": { - "duration": 112.394193, - "end_time": "2024-03-23T16:13:38.694589", - "exception": false, - "start_time": "2024-03-23T16:11:46.300396", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import cudf\n", - "fn = cudf.read_parquet\n", - "timings = %timeit -o -q q1(fn(lineitem))\n", - "results['q1'] = timings.all_runs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12", - "metadata": { - "papermill": { - "duration": 7.722805, - "end_time": "2024-03-23T16:13:46.572524", 
- "exception": false, - "start_time": "2024-03-23T16:13:38.849719", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import cudf\n", - "fn = cudf.read_parquet\n", - "timings = %timeit -o -q q2(fn(region), fn(nation), fn(supplier), fn(part), fn(partsupp))\n", - "results['q2'] = timings.all_runs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13", - "metadata": { - "papermill": { - "duration": 8.745438, - "end_time": "2024-03-23T16:13:55.465888", - "exception": false, - "start_time": "2024-03-23T16:13:46.720450", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import cudf\n", - "fn = cudf.read_parquet\n", - "timings = %timeit -o -q q3(fn(customer), fn(lineitem), fn(orders))\n", - "results['q3'] = timings.all_runs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "14", - "metadata": { - "papermill": { - "duration": 9.273157, - "end_time": "2024-03-23T16:14:04.904258", - "exception": false, - "start_time": "2024-03-23T16:13:55.631101", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import cudf\n", - "fn = cudf.read_parquet\n", - "timings = %timeit -o -q q4(fn(lineitem), fn(orders))\n", - "results['q4'] = timings.all_runs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15", - "metadata": { - "papermill": { - "duration": 11.308626, - "end_time": "2024-03-23T16:14:16.369678", - "exception": false, - "start_time": "2024-03-23T16:14:05.061052", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import cudf\n", - "fn = cudf.read_parquet\n", - "timings = %timeit -o -q q5(fn(region), fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", - "results['q5'] = timings.all_runs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16", - "metadata": { - "papermill": { - "duration": 0.157455, - "end_time": "2024-03-23T16:14:16.689793", - "exception": false, - "start_time": "2024-03-23T16:14:16.532338", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)" - ] - } - ], - "metadata": { - "kaggle": { - "accelerator": "nvidiaTeslaT4", - "dataSources": [ - { - "sourceId": 167796934, - "sourceType": "kernelVersion" - } - ], - "dockerImageVersionId": 30674, - "isGpuEnabled": true, - "isInternetEnabled": true, - "language": "python", - "sourceType": "notebook" - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - }, - "papermill": { - "default_parameters": {}, - "duration": 451.646729, - "end_time": "2024-03-23T16:14:17.553113", - "environment_variables": {}, - "exception": null, - "input_path": "__notebook__.ipynb", - "output_path": "__notebook__.ipynb", - "parameters": {}, - "start_time": "2024-03-23T16:06:45.906384", - "version": "2.5.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tpch/notebooks/gpu/kernel-metadata.json b/tpch/notebooks/gpu/kernel-metadata.json deleted file mode 100644 index 128aea001..000000000 --- a/tpch/notebooks/gpu/kernel-metadata.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "id": "marcogorelli/narwhals-tpch-gpu-s2", - "title": 
"Narwhals TPCH GPU S2", - "code_file": "execute.ipynb", - "language": "python", - "kernel_type": "notebook", - "is_private": "false", - "enable_gpu": "true", - "enable_tpu": "false", - "enable_internet": "true", - "dataset_sources": [], - "competition_sources": [], - "kernel_sources": ["marcogorelli/tpc-h-data-parquet-s-2"], - "model_sources": [] -} \ No newline at end of file diff --git a/tpch/notebooks/q1/execute.ipynb b/tpch/notebooks/q1/execute.ipynb index cc6dd4559..de9c52baa 100755 --- a/tpch/notebooks/q1/execute.ipynb +++ b/tpch/notebooks/q1/execute.ipynb @@ -58,10 +58,12 @@ }, "outputs": [], "source": [ - "from typing import Any\n", "from datetime import datetime\n", + "from typing import Any\n", + "\n", "import narwhals as nw\n", "\n", + "\n", "@nw.narwhalify\n", "def q1(lineitem_ds: Any) -> Any:\n", " var_1 = datetime(1998, 9, 2)\n", @@ -107,14 +109,14 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "region = dir_ + 'region.parquet'\n", - "nation = dir_ + 'nation.parquet'\n", - "customer = dir_ + 'customer.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "part = dir_ + 'part.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" + "region = dir_ + \"region.parquet\"\n", + "nation = dir_ + \"nation.parquet\"\n", + "customer = dir_ + \"customer.parquet\"\n", + "lineitem = dir_ + \"lineitem.parquet\"\n", + "orders = dir_ + \"orders.parquet\"\n", + "supplier = dir_ + \"supplier.parquet\"\n", + "part = dir_ + \"part.parquet\"\n", + "partsupp = dir_ + \"partsupp.parquet\"" ] }, { @@ -133,16 +135,18 @@ }, "outputs": [], "source": [ - "import pyarrow.parquet as pq\n", "import dask.dataframe as dd\n", + "import pyarrow.parquet as pq\n", "\n", "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", - " 'pyarrow': lambda x: pq.read_table(x),\n", - " 'dask': lambda x: dd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", + " \"pyarrow\": lambda x: pq.read_table(x),\n", + " \"dask\": lambda x: dd.read_parquet(x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"),\n", "}" ] }, @@ -171,7 +175,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pyarrow'\n", + "tool = \"pyarrow\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q1(fn(lineitem))\n", "results[tool] = timings.all_runs" @@ -210,7 +214,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q1(lineitem_ds=fn(lineitem))\n", "results[tool] = timings.all_runs" @@ -249,7 +253,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q1(fn(lineitem))\n", "results[tool] = timings.all_runs" @@ -288,7 +292,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q1(fn(lineitem))\n", "results[tool] = 
timings.all_runs" @@ -327,7 +331,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q1(fn(lineitem)).collect()\n", "results[tool] = timings.all_runs" @@ -348,7 +352,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'dask'\n", + "tool = \"dask\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q1(fn(lineitem)).collect()\n", "results[tool] = timings.all_runs" @@ -370,8 +374,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q10/execute.ipynb b/tpch/notebooks/q10/execute.ipynb index 85ec0f14b..9ff211773 100644 --- a/tpch/notebooks/q10/execute.ipynb +++ b/tpch/notebooks/q10/execute.ipynb @@ -55,22 +55,23 @@ }, "outputs": [], "source": [ - "from typing import Any\n", "from datetime import datetime\n", + "from typing import Any\n", + "\n", "import narwhals as nw\n", "\n", + "\n", "def q10(\n", " customer_ds_raw: Any,\n", " nation_ds_raw: Any,\n", " lineitem_ds_raw: Any,\n", " orders_ds_raw: Any,\n", ") -> Any:\n", - "\n", " nation_ds = nw.from_native(nation_ds_raw)\n", " line_item_ds = nw.from_native(lineitem_ds_raw)\n", " orders_ds = nw.from_native(orders_ds_raw)\n", " customer_ds = nw.from_native(customer_ds_raw)\n", - " \n", + "\n", " var1 = datetime(1993, 10, 1)\n", " var2 = datetime(1994, 1, 1)\n", "\n", @@ -81,8 +82,7 @@ " .filter(nw.col(\"o_orderdate\").is_between(var1, var2, closed=\"left\"))\n", " .filter(nw.col(\"l_returnflag\") == \"R\")\n", " .with_columns(\n", - " (nw.col(\"l_extendedprice\") * (1 - nw.col(\"l_discount\")))\n", - " .alias(\"revenue\")\n", + " (nw.col(\"l_extendedprice\") * (1 - nw.col(\"l_discount\"))).alias(\"revenue\")\n", " )\n", " .group_by(\n", " \"c_custkey\",\n", @@ -127,10 +127,10 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "nation = dir_ + 'nation.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "customer = dir_ + 'customer.parquet'" + "nation = dir_ + \"nation.parquet\"\n", + "lineitem = dir_ + \"lineitem.parquet\"\n", + "orders = dir_ + \"orders.parquet\"\n", + "customer = dir_ + \"customer.parquet\"" ] }, { @@ -149,10 +149,12 @@ "outputs": [], "source": [ "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", "}" ] }, @@ -196,7 +198,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q10(fn(customer), fn(nation), fn(lineitem), fn(orders))\n", "results[tool] = timings.all_runs" @@ -233,7 +235,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q10(fn(customer), fn(nation), fn(lineitem), fn(orders))\n", "results[tool] = timings.all_runs" @@ -270,7 +272,7 @@ }, 
"outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q10(fn(customer), fn(nation), fn(lineitem), fn(orders))\n", "results[tool] = timings.all_runs" @@ -307,7 +309,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q10(fn(customer), fn(nation), fn(lineitem), fn(orders)).collect()\n", "results[tool] = timings.all_runs" @@ -327,8 +329,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q11/execute.ipynb b/tpch/notebooks/q11/execute.ipynb index 33951d922..f5bbc0f9c 100644 --- a/tpch/notebooks/q11/execute.ipynb +++ b/tpch/notebooks/q11/execute.ipynb @@ -15,7 +15,7 @@ }, "outputs": [], "source": [ - "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals " + "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals" ] }, { @@ -56,19 +56,19 @@ "outputs": [], "source": [ "from typing import Any\n", + "\n", "import narwhals as nw\n", "\n", + "\n", "def q11(\n", " partsupp_ds_raw: Any,\n", " nation_ds_raw: Any,\n", " supplier_ds_raw: Any,\n", ") -> Any:\n", - "\n", " nation_ds = nw.from_native(nation_ds_raw)\n", " partsupp_ds = nw.from_native(partsupp_ds_raw)\n", " supplier_ds = nw.from_native(supplier_ds_raw)\n", "\n", - " \n", " var1 = \"GERMANY\"\n", " var2 = 0.0001\n", "\n", @@ -83,14 +83,9 @@ " )\n", "\n", " q_final = (\n", - " q1.with_columns(\n", - " (nw.col(\"ps_supplycost\") * nw.col(\"ps_availqty\"))\n", - " .alias(\"value\")\n", - " )\n", + " q1.with_columns((nw.col(\"ps_supplycost\") * nw.col(\"ps_availqty\")).alias(\"value\"))\n", " .group_by(\"ps_partkey\")\n", - " .agg(\n", - " nw.sum(\"value\")\n", - " )\n", + " .agg(nw.sum(\"value\"))\n", " .join(q2, how=\"cross\")\n", " .filter(nw.col(\"value\") > nw.col(\"tmp\"))\n", " .select(\"ps_partkey\", \"value\")\n", @@ -116,9 +111,9 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "nation = dir_ + 'nation.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" + "nation = dir_ + \"nation.parquet\"\n", + "supplier = dir_ + \"supplier.parquet\"\n", + "partsupp = dir_ + \"partsupp.parquet\"" ] }, { @@ -137,10 +132,12 @@ "outputs": [], "source": [ "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", "}" ] }, @@ -184,7 +181,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q11(fn(partsupp), fn(nation), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -221,7 +218,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q11(fn(partsupp), 
fn(nation), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -258,7 +255,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q11(fn(partsupp), fn(nation), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -295,7 +292,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q11(fn(partsupp), fn(nation), fn(supplier)).collect()\n", "results[tool] = timings.all_runs" @@ -315,8 +312,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] }, { diff --git a/tpch/notebooks/q15/execute.ipynb b/tpch/notebooks/q15/execute.ipynb index 0baf11956..d108a7196 100644 --- a/tpch/notebooks/q15/execute.ipynb +++ b/tpch/notebooks/q15/execute.ipynb @@ -15,7 +15,7 @@ }, "outputs": [], "source": [ - "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals " + "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals" ] }, { @@ -55,32 +55,34 @@ }, "outputs": [], "source": [ + "from datetime import datetime\n", "from typing import Any\n", + "\n", "import narwhals as nw\n", - "from datetime import datetime\n", + "\n", "\n", "def q15(\n", " lineitem_ds_raw: Any,\n", " supplier_ds_raw: Any,\n", ") -> Any:\n", - "\n", " lineitem_ds = nw.from_native(lineitem_ds_raw)\n", " supplier_ds = nw.from_native(supplier_ds_raw)\n", - " \n", + "\n", " var1 = datetime(1996, 1, 1)\n", " var2 = datetime(1996, 4, 1)\n", "\n", " revenue = (\n", " lineitem_ds.filter(nw.col(\"l_shipdate\").is_between(var1, var2, closed=\"left\"))\n", " .with_columns(\n", - " (nw.col(\"l_extendedprice\") * (1 - nw.col(\"l_discount\")))\n", - " .alias(\"total_revenue\")\n", + " (nw.col(\"l_extendedprice\") * (1 - nw.col(\"l_discount\"))).alias(\n", + " \"total_revenue\"\n", + " )\n", " )\n", " .group_by(\"l_suppkey\")\n", " .agg(nw.sum(\"total_revenue\"))\n", " .select(nw.col(\"l_suppkey\").alias(\"supplier_no\"), nw.col(\"total_revenue\"))\n", " )\n", - " \n", + "\n", " result = (\n", " supplier_ds.join(revenue, left_on=\"s_suppkey\", right_on=\"supplier_no\")\n", " .filter(nw.col(\"total_revenue\") == nw.col(\"total_revenue\").max())\n", @@ -108,8 +110,8 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "supplier = dir_ + 'supplier.parquet'" + "lineitem = dir_ + \"lineitem.parquet\"\n", + "supplier = dir_ + \"supplier.parquet\"" ] }, { @@ -128,10 +130,12 @@ "outputs": [], "source": [ "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", "}" ] }, @@ -175,7 +179,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q15(fn(lineitem), fn(supplier))\n", "results[tool] = 
timings.all_runs" @@ -212,7 +216,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q15(fn(lineitem), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -249,7 +253,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q15(fn(lineitem), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -286,7 +290,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q15(fn(lineitem), fn(supplier)).collect()\n", "results[tool] = timings.all_runs" @@ -306,8 +310,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q17/execute.ipynb b/tpch/notebooks/q17/execute.ipynb index b13445d28..4d012f088 100644 --- a/tpch/notebooks/q17/execute.ipynb +++ b/tpch/notebooks/q17/execute.ipynb @@ -15,7 +15,7 @@ }, "outputs": [], "source": [ - "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals " + "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals" ] }, { @@ -56,25 +56,23 @@ "outputs": [], "source": [ "from typing import Any\n", + "\n", "import narwhals as nw\n", "\n", - "def q17(\n", - " lineitem_ds_raw: Any,\n", - " part_ds_raw: Any\n", - ") -> Any:\n", "\n", + "def q17(lineitem_ds_raw: Any, part_ds_raw: Any) -> Any:\n", " lineitem_ds = nw.from_native(lineitem_ds_raw)\n", " part_ds = nw.from_native(part_ds_raw)\n", - " \n", + "\n", " var1 = \"Brand#23\"\n", " var2 = \"MED BOX\"\n", - " \n", + "\n", " query1 = (\n", " part_ds.filter(nw.col(\"p_brand\") == var1)\n", " .filter(nw.col(\"p_container\") == var2)\n", " .join(lineitem_ds, how=\"left\", left_on=\"p_partkey\", right_on=\"l_partkey\")\n", " )\n", - " \n", + "\n", " final_query = (\n", " query1.group_by(\"p_partkey\")\n", " .agg((0.2 * nw.col(\"l_quantity\").mean()).alias(\"avg_quantity\"))\n", @@ -84,7 +82,6 @@ " .select((nw.col(\"l_extendedprice\").sum() / 7.0).round(2).alias(\"avg_yearly\"))\n", " )\n", "\n", - "\n", " return nw.to_native(final_query)" ] }, @@ -104,8 +101,8 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "part = dir_ + 'part.parquet'" + "lineitem = dir_ + \"lineitem.parquet\"\n", + "part = dir_ + \"part.parquet\"" ] }, { @@ -124,10 +121,12 @@ "outputs": [], "source": [ "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", "}" ] }, @@ -171,7 +170,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q17(fn(lineitem), fn(part))\n", "results[tool] = timings.all_runs" @@ -208,7 +207,7 @@ }, "outputs": [], "source": [ - "tool = 
'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q17(fn(lineitem), fn(part))\n", "results[tool] = timings.all_runs" @@ -245,7 +244,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q17(fn(lineitem), fn(part))\n", "results[tool] = timings.all_runs" @@ -282,7 +281,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q17(fn(lineitem), fn(part)).collect()\n", "results[tool] = timings.all_runs" @@ -302,8 +301,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q18/execute.ipynb b/tpch/notebooks/q18/execute.ipynb index c90629e0f..edf635d9e 100644 --- a/tpch/notebooks/q18/execute.ipynb +++ b/tpch/notebooks/q18/execute.ipynb @@ -6,7 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals " + "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals" ] }, { @@ -29,18 +29,15 @@ "outputs": [], "source": [ "from typing import Any\n", + "\n", "import narwhals as nw\n", "\n", - "def q18(\n", - " customer_ds_raw: Any,\n", - " lineitem_ds_raw: Any,\n", - " orders_ds_raw: Any\n", - ") -> Any:\n", "\n", + "def q18(customer_ds_raw: Any, lineitem_ds_raw: Any, orders_ds_raw: Any) -> Any:\n", " customer_ds = nw.from_native(customer_ds_raw)\n", " lineitem_ds = nw.from_native(lineitem_ds_raw)\n", " orders_ds = nw.from_native(orders_ds_raw)\n", - " \n", + "\n", " var1 = 300\n", "\n", " query1 = (\n", @@ -67,7 +64,6 @@ " .head(100)\n", " )\n", "\n", - "\n", " return nw.to_native(q_final)" ] }, @@ -78,9 +74,9 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "customer = dir_ + 'customer.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'" + "customer = dir_ + \"customer.parquet\"\n", + "lineitem = dir_ + \"lineitem.parquet\"\n", + "orders = dir_ + \"orders.parquet\"" ] }, { @@ -90,10 +86,12 @@ "outputs": [], "source": [ "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", "}" ] }, @@ -119,7 +117,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q19(fn(lineitem), fn(part))\n", "results[tool] = timings.all_runs" @@ -138,7 +136,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q18(fn(customer), fn(lineitem), fn(orders))\n", "results[tool] = timings.all_runs" @@ -157,7 +155,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = 
\"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q18(fn(customer), fn(lineitem), fn(orders))\n", "results[tool] = timings.all_runs" @@ -176,7 +174,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q18(fn(customer), fn(lineitem), fn(orders)).collect()\n", "results[tool] = timings.all_runs" @@ -196,8 +194,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q19/execute.ipynb b/tpch/notebooks/q19/execute.ipynb index 8483e06d5..8860cc773 100644 --- a/tpch/notebooks/q19/execute.ipynb +++ b/tpch/notebooks/q19/execute.ipynb @@ -15,7 +15,7 @@ }, "outputs": [], "source": [ - "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals " + "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals" ] }, { @@ -56,14 +56,11 @@ "outputs": [], "source": [ "from typing import Any\n", + "\n", "import narwhals as nw\n", "\n", - "def q19(\n", - " lineitem_ds_raw: Any,\n", - " part_ds_raw: Any\n", - " \n", - ") -> Any:\n", "\n", + "def q19(lineitem_ds_raw: Any, part_ds_raw: Any) -> Any:\n", " lineitem_ds = nw.from_native(lineitem_ds_raw)\n", " part_ds = nw.from_native(part_ds_raw)\n", "\n", @@ -74,9 +71,7 @@ " .filter(\n", " (\n", " (nw.col(\"p_brand\") == \"Brand#12\")\n", - " & nw.col(\"p_container\").is_in(\n", - " [\"SM CASE\", \"SM BOX\", \"SM PACK\", \"SM PKG\"]\n", - " )\n", + " & nw.col(\"p_container\").is_in([\"SM CASE\", \"SM BOX\", \"SM PACK\", \"SM PKG\"])\n", " & (nw.col(\"l_quantity\").is_between(1, 11))\n", " & (nw.col(\"p_size\").is_between(1, 5))\n", " )\n", @@ -90,9 +85,7 @@ " )\n", " | (\n", " (nw.col(\"p_brand\") == \"Brand#34\")\n", - " & nw.col(\"p_container\").is_in(\n", - " [\"LG CASE\", \"LG BOX\", \"LG PACK\", \"LG PKG\"]\n", - " )\n", + " & nw.col(\"p_container\").is_in([\"LG CASE\", \"LG BOX\", \"LG PACK\", \"LG PKG\"])\n", " & (nw.col(\"l_quantity\").is_between(20, 30))\n", " & (nw.col(\"p_size\").is_between(1, 15))\n", " )\n", @@ -105,7 +98,6 @@ " )\n", " )\n", "\n", - "\n", " return nw.to_native(result)" ] }, @@ -125,8 +117,8 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "part = dir_ + 'part.parquet'" + "lineitem = dir_ + \"lineitem.parquet\"\n", + "part = dir_ + \"part.parquet\"" ] }, { @@ -145,10 +137,12 @@ "outputs": [], "source": [ "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", "}" ] }, @@ -192,7 +186,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q19(fn(lineitem), fn(part))\n", "results[tool] = timings.all_runs" @@ -229,7 +223,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = 
\"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q19(fn(lineitem), fn(part))\n", "results[tool] = timings.all_runs" @@ -266,7 +260,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q19(fn(lineitem), fn(part))\n", "results[tool] = timings.all_runs" @@ -303,7 +297,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q19(fn(lineitem), fn(part)).collect()\n", "results[tool] = timings.all_runs" @@ -323,8 +317,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q2/execute.ipynb b/tpch/notebooks/q2/execute.ipynb index c05345336..74ba50f2a 100755 --- a/tpch/notebooks/q2/execute.ipynb +++ b/tpch/notebooks/q2/execute.ipynb @@ -69,8 +69,10 @@ "outputs": [], "source": [ "from typing import Any\n", + "\n", "import narwhals as nw\n", "\n", + "\n", "@nw.narwhalify\n", "def q2(\n", " region_ds: Any,\n", @@ -140,14 +142,14 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "region = dir_ + 'region.parquet'\n", - "nation = dir_ + 'nation.parquet'\n", - "customer = dir_ + 'customer.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "part = dir_ + 'part.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" + "region = dir_ + \"region.parquet\"\n", + "nation = dir_ + \"nation.parquet\"\n", + "customer = dir_ + \"customer.parquet\"\n", + "lineitem = dir_ + \"lineitem.parquet\"\n", + "orders = dir_ + \"orders.parquet\"\n", + "supplier = dir_ + \"supplier.parquet\"\n", + "part = dir_ + \"part.parquet\"\n", + "partsupp = dir_ + \"partsupp.parquet\"" ] }, { @@ -166,16 +168,18 @@ }, "outputs": [], "source": [ - "import pyarrow.parquet as pq\n", "import dask.dataframe as dd\n", + "import pyarrow.parquet as pq\n", "\n", "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", - " 'pyarrow': lambda x: pq.read_table(x),\n", - " 'dask': lambda x: dd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", + " \"pyarrow\": lambda x: pq.read_table(x),\n", + " \"dask\": lambda x: dd.read_parquet(x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"),\n", "}" ] }, @@ -222,7 +226,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q2(fn(region), fn(nation), fn(supplier), fn(part), fn(partsupp))\n", "results[tool] = timings.all_runs" @@ -261,7 +265,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q2(fn(region), fn(nation), fn(supplier), fn(part), fn(partsupp))\n", "results[tool] = 
timings.all_runs" @@ -300,7 +304,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q2(fn(region), fn(nation), fn(supplier), fn(part), fn(partsupp))\n", "results[tool] = timings.all_runs" @@ -339,7 +343,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q2(fn(region), fn(nation), fn(supplier), fn(part), fn(partsupp)).collect()\n", "results[tool] = timings.all_runs" @@ -360,7 +364,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pyarrow'\n", + "tool = \"pyarrow\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q2(fn(region), fn(nation), fn(supplier), fn(part), fn(partsupp))\n", "results[tool] = timings.all_runs" @@ -381,7 +385,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'dask'\n", + "tool = \"dask\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q2(fn(region), fn(nation), fn(supplier), fn(part), fn(partsupp)).compute()\n", "results[tool] = timings.all_runs" @@ -403,8 +407,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q20/execute.ipynb b/tpch/notebooks/q20/execute.ipynb index aecb3a473..a9698c1ad 100644 --- a/tpch/notebooks/q20/execute.ipynb +++ b/tpch/notebooks/q20/execute.ipynb @@ -15,7 +15,7 @@ }, "outputs": [], "source": [ - "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals " + "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals" ] }, { @@ -55,24 +55,25 @@ }, "outputs": [], "source": [ + "from datetime import datetime\n", "from typing import Any\n", + "\n", "import narwhals as nw\n", - "from datetime import datetime\n", + "\n", "\n", "def q20(\n", " part_ds_raw: Any,\n", " partsupp_ds_raw: Any,\n", " nation_ds_raw: Any,\n", " lineitem_ds_raw: Any,\n", - " supplier_ds_raw: Any\n", + " supplier_ds_raw: Any,\n", ") -> Any:\n", - "\n", " part_ds = nw.from_native(part_ds_raw)\n", " nation_ds = nw.from_native(nation_ds_raw)\n", " partsupp_ds = nw.from_native(partsupp_ds_raw)\n", " lineitem_ds = nw.from_native(lineitem_ds_raw)\n", " supplier_ds = nw.from_native(supplier_ds_raw)\n", - " \n", + "\n", " var1 = datetime(1994, 1, 1)\n", " var2 = datetime(1995, 1, 1)\n", " var3 = \"CANADA\"\n", @@ -82,7 +83,7 @@ " lineitem_ds.filter(nw.col(\"l_shipdate\").is_between(var1, var2, closed=\"left\"))\n", " .group_by(\"l_partkey\", \"l_suppkey\")\n", " .agg((nw.col(\"l_quantity\").sum()).alias(\"sum_quantity\"))\n", - " .with_columns(sum_quantity = nw.col(\"sum_quantity\") * 0.5)\n", + " .with_columns(sum_quantity=nw.col(\"sum_quantity\") * 0.5)\n", " )\n", " query2 = nation_ds.filter(nw.col(\"n_name\") == var3)\n", " query3 = supplier_ds.join(query2, left_on=\"s_nationkey\", right_on=\"n_nationkey\")\n", @@ -103,7 +104,6 @@ " .sort(\"s_name\")\n", " )\n", "\n", - "\n", " return nw.to_native(result)" ] }, @@ -123,11 +123,11 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "nation = dir_ + 'nation.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "part = dir_ + 'part.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" + "nation = dir_ + \"nation.parquet\"\n", + "lineitem = dir_ + \"lineitem.parquet\"\n", + "supplier = dir_ + 
\"supplier.parquet\"\n", + "part = dir_ + \"part.parquet\"\n", + "partsupp = dir_ + \"partsupp.parquet\"" ] }, { @@ -146,10 +146,12 @@ "outputs": [], "source": [ "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", "}" ] }, @@ -193,7 +195,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q20(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -230,7 +232,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q20(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -267,7 +269,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q20(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -304,7 +306,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q20(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(supplier)).collect()\n", "results[tool] = timings.all_runs" @@ -324,8 +326,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q21/execute.ipynb b/tpch/notebooks/q21/execute.ipynb index b51b15dce..af12a424c 100755 --- a/tpch/notebooks/q21/execute.ipynb +++ b/tpch/notebooks/q21/execute.ipynb @@ -36,13 +36,12 @@ "outputs": [], "source": [ "from typing import Any\n", - "from datetime import date\n", - "\n", - "import narwhals as nw\n", "\n", "import pandas as pd\n", "import polars as pl\n", "\n", + "import narwhals as nw\n", + "\n", "pd.options.mode.copy_on_write = True\n", "pd.options.future.infer_string = True" ] @@ -66,10 +65,12 @@ "Q_NUM = 21\n", "\n", "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", "}" ] }, @@ -95,34 +96,28 @@ " orders_raw: Any,\n", " supplier_raw: Any,\n", ") -> Any:\n", - " \n", " lineitem = nw.from_native(lineitem_raw)\n", " nation = nw.from_native(nation_raw)\n", " orders = nw.from_native(orders_raw)\n", " supplier = nw.from_native(supplier_raw)\n", - " \n", + "\n", " var1 = \"SAUDI 
ARABIA\"\n", - " \n", - " \n", + "\n", " q1 = (\n", " lineitem.group_by(\"l_orderkey\")\n", - "# .agg(nw.col(\"l_suppkey\").len().alias(\"n_supp_by_order\"))\n", " .agg(nw.len().alias(\"n_supp_by_order\"))\n", " .filter(nw.col(\"n_supp_by_order\") > 1)\n", " .join(\n", " lineitem.filter(nw.col(\"l_receiptdate\") > nw.col(\"l_commitdate\")),\n", - "# on=\"l_orderkey\",\n", - " left_on=\"l_orderkey\", right_on=\"l_orderkey\",\n", + " left_on=\"l_orderkey\",\n", + " right_on=\"l_orderkey\",\n", " )\n", " )\n", "\n", " q_final = (\n", " q1.group_by(\"l_orderkey\")\n", - "# .agg(nw.col(\"l_suppkey\").len().alias(\"n_supp_by_order\"))\n", " .agg(nw.len().alias(\"n_supp_by_order\"))\n", - " .join(q1, left_on=\"l_orderkey\", right_on=\"l_orderkey\"\n", - " #on=\"l_orderkey\"\n", - " )\n", + " .join(q1, left_on=\"l_orderkey\", right_on=\"l_orderkey\")\n", " .join(supplier, left_on=\"l_suppkey\", right_on=\"s_suppkey\")\n", " .join(nation, left_on=\"s_nationkey\", right_on=\"n_nationkey\")\n", " .join(orders, left_on=\"l_orderkey\", right_on=\"o_orderkey\")\n", @@ -155,10 +150,10 @@ "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", "\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "nation = dir_ + 'nation.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "supplier = dir_ + 'supplier.parquet'" + "lineitem = dir_ + \"lineitem.parquet\"\n", + "nation = dir_ + \"nation.parquet\"\n", + "orders = dir_ + \"orders.parquet\"\n", + "supplier = dir_ + \"supplier.parquet\"" ] }, { @@ -213,10 +208,15 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "\n", - "lineitem_raw, nation_raw, orders_raw, supplier_raw = fn(lineitem), fn(nation), fn(orders), fn(supplier)\n", + "lineitem_raw, nation_raw, orders_raw, supplier_raw = (\n", + " fn(lineitem),\n", + " fn(nation),\n", + " fn(orders),\n", + " fn(supplier),\n", + ")\n", "\n", "timings = %timeit -o -q q21(lineitem_raw, nation_raw, orders_raw, supplier_raw)\n", "results[tool] = timings.all_runs" @@ -255,9 +255,14 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", - "lineitem_raw, nation_raw, orders_raw, supplier_raw = fn(lineitem), fn(nation), fn(orders), fn(supplier)\n", + "lineitem_raw, nation_raw, orders_raw, supplier_raw = (\n", + " fn(lineitem),\n", + " fn(nation),\n", + " fn(orders),\n", + " fn(supplier),\n", + ")\n", "\n", "timings = %timeit -o -q q21(lineitem_raw, nation_raw, orders_raw, supplier_raw)\n", "results[tool] = timings.all_runs" @@ -296,10 +301,15 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "\n", - "lineitem_raw, nation_raw, orders_raw, supplier_raw = fn(lineitem), fn(nation), fn(orders), fn(supplier)\n", + "lineitem_raw, nation_raw, orders_raw, supplier_raw = (\n", + " fn(lineitem),\n", + " fn(nation),\n", + " fn(orders),\n", + " fn(supplier),\n", + ")\n", "timings = %timeit -o -q q21(lineitem_raw, nation_raw, orders_raw, supplier_raw)\n", "results[tool] = timings.all_runs" ] @@ -337,10 +347,15 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "\n", - "lineitem_raw, nation_raw, orders_raw, supplier_raw = fn(lineitem), fn(nation), fn(orders), fn(supplier)\n", + "lineitem_raw, nation_raw, orders_raw, supplier_raw = (\n", + " fn(lineitem),\n", + " fn(nation),\n", + " fn(orders),\n", + " fn(supplier),\n", + ")\n", "timings = %timeit -o -q q21(lineitem_raw, 
nation_raw, orders_raw, supplier_raw).collect()\n", "results[tool] = timings.all_runs" ] @@ -379,29 +394,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16", - "metadata": { - "papermill": { - "duration": 0.02616, - "end_time": "2024-06-20T09:46:18.666732", - "exception": false, - "start_time": "2024-06-20T09:46:18.640572", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from pprint import pprint\n", "\n", - "pprint(results)" + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q3/execute.ipynb b/tpch/notebooks/q3/execute.ipynb index 80178cae1..b81135fc3 100755 --- a/tpch/notebooks/q3/execute.ipynb +++ b/tpch/notebooks/q3/execute.ipynb @@ -49,14 +49,15 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Any\n", "from datetime import date\n", + "from typing import Any\n", + "\n", "\n", "def q3_pandas_native(\n", " customer_ds: Any,\n", " line_item_ds: Any,\n", " orders_ds: Any,\n", - "):\n", + ") -> Any:\n", " var1 = \"BUILDING\"\n", " var2 = date(1995, 3, 15)\n", "\n", @@ -69,18 +70,15 @@ " jn2 = jn2[jn2[\"l_shipdate\"] > var2]\n", " jn2[\"revenue\"] = jn2.l_extendedprice * (1 - jn2.l_discount)\n", "\n", - " gb = jn2.groupby(\n", - " [\"o_orderkey\", \"o_orderdate\", \"o_shippriority\"], as_index=False\n", - " )\n", + " gb = jn2.groupby([\"o_orderkey\", \"o_orderdate\", \"o_shippriority\"], as_index=False)\n", " agg = gb[\"revenue\"].sum()\n", "\n", " sel = agg.loc[:, [\"o_orderkey\", \"revenue\", \"o_orderdate\", \"o_shippriority\"]]\n", " sel = sel.rename({\"o_orderkey\": \"l_orderkey\"}, axis=\"columns\")\n", "\n", " sorted = sel.sort_values(by=[\"revenue\", \"o_orderdate\"], ascending=[False, True])\n", - " result_df = sorted.head(10)\n", "\n", - " return result_df # type: ignore[no-any-return]" + " return sorted.head(10) # type: ignore[no-any-return]" ] }, { @@ -99,10 +97,12 @@ }, "outputs": [], "source": [ - "from typing import Any\n", "from datetime import datetime\n", + "from typing import Any\n", + "\n", "import narwhals as nw\n", "\n", + "\n", "def q3(\n", " customer_ds_raw: Any,\n", " line_item_ds_raw: Any,\n", @@ -122,7 +122,8 @@ " .filter(\n", " nw.col(\"o_orderdate\") < var_2,\n", " nw.col(\"l_shipdate\") > var_1,\n", - " ).with_columns(\n", + " )\n", + " .with_columns(\n", " (nw.col(\"l_extendedprice\") * (1 - nw.col(\"l_discount\"))).alias(\"revenue\")\n", " )\n", " .group_by([\"o_orderkey\", \"o_orderdate\", \"o_shippriority\"])\n", @@ -150,16 +151,16 @@ "outputs": [], "source": [ "from typing import Any\n", - "from datetime import datetime\n", - "import narwhals as nw\n", + "\n", "import ibis\n", "\n", + "\n", "def q3_ibis(\n", " customer: Any,\n", " lineitem: Any,\n", " orders: Any,\n", " *,\n", - " tool,\n", + " tool: str,\n", ") -> Any:\n", " var1 = \"BUILDING\"\n", " var2 = date(1995, 3, 15)\n", @@ -186,9 +187,9 @@ " .order_by(ibis.desc(\"revenue\"), \"o_orderdate\")\n", " .limit(10)\n", " )\n", - " if tool == 'pandas':\n", + " if tool == \"pandas\":\n", " return q_final.to_pandas()\n", - " if tool == 'polars':\n", + " if tool == \"polars\":\n", " return q_final.to_polars()\n", " raise ValueError(\"expected pandas or polars\")" ] @@ -210,14 +211,14 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "region = dir_ + 'region.parquet'\n", - "nation = dir_ + 'nation.parquet'\n", - 
"customer = dir_ + 'customer.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "part = dir_ + 'part.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" + "region = dir_ + \"region.parquet\"\n", + "nation = dir_ + \"nation.parquet\"\n", + "customer = dir_ + \"customer.parquet\"\n", + "lineitem = dir_ + \"lineitem.parquet\"\n", + "orders = dir_ + \"orders.parquet\"\n", + "supplier = dir_ + \"supplier.parquet\"\n", + "part = dir_ + \"part.parquet\"\n", + "partsupp = dir_ + \"partsupp.parquet\"" ] }, { @@ -236,18 +237,20 @@ }, "outputs": [], "source": [ - "import ibis\n", - "\n", "con_pd = ibis.pandas.connect()\n", "con_pl = ibis.polars.connect()\n", "\n", "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'pandas[pyarrow][ibis]': lambda x: con_pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", - " 'polars[lazy][ibis]': lambda x: con_pl.read_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"pandas[pyarrow][ibis]\": lambda x: con_pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", + " \"polars[lazy][ibis]\": lambda x: con_pl.read_parquet(x),\n", "}" ] }, @@ -276,7 +279,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pandas[pyarrow][ibis]'\n", + "tool = \"pandas[pyarrow][ibis]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q3_ibis(fn(customer), fn(lineitem), fn(orders), tool='pandas')\n", "results[tool] = timings.all_runs" @@ -297,7 +300,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'polars[lazy][ibis]'\n", + "tool = \"polars[lazy][ibis]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q3_ibis(fn(customer), fn(lineitem), fn(orders), tool='polars')\n", "results[tool] = timings.all_runs" @@ -318,10 +321,10 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q3_pandas_native(fn(customer), fn(lineitem), fn(orders))\n", - "results[tool+'[native]'] = timings.all_runs" + "results[tool + \"[native]\"] = timings.all_runs" ] }, { @@ -357,7 +360,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q3(fn(customer), fn(lineitem), fn(orders))\n", "results[tool] = timings.all_runs" @@ -396,7 +399,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q3(fn(customer), fn(lineitem), fn(orders))\n", "results[tool] = timings.all_runs" @@ -435,7 +438,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q3(fn(customer), fn(lineitem), fn(orders))\n", "results[tool] = timings.all_runs" @@ -474,7 +477,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q 
q3(fn(customer), fn(lineitem), fn(orders)).collect()\n", "results[tool] = timings.all_runs" @@ -496,8 +499,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q4/execute.ipynb b/tpch/notebooks/q4/execute.ipynb index df07c9c5f..b0a55e345 100755 --- a/tpch/notebooks/q4/execute.ipynb +++ b/tpch/notebooks/q4/execute.ipynb @@ -52,6 +52,7 @@ "from datetime import date\n", "from typing import Any\n", "\n", + "\n", "def q4_pandas_native(\n", " line_item_ds: Any,\n", " orders_ds: Any,\n", @@ -72,9 +73,7 @@ " gb = jn.groupby(\"o_orderpriority\", as_index=False)\n", " agg = gb.agg(order_count=pd.NamedAgg(column=\"o_orderkey\", aggfunc=\"count\"))\n", "\n", - " result_df = agg.sort_values([\"o_orderpriority\"])\n", - "\n", - " return result_df # type: ignore[no-any-return]" + " return agg.sort_values([\"o_orderpriority\"]) # type: ignore[no-any-return]" ] }, { @@ -93,10 +92,12 @@ }, "outputs": [], "source": [ - "from typing import Any\n", "from datetime import datetime\n", + "from typing import Any\n", + "\n", "import narwhals as nw\n", "\n", + "\n", "def q4(\n", " lineitem_ds_raw: Any,\n", " orders_ds_raw: Any,\n", @@ -112,7 +113,8 @@ " .filter(\n", " nw.col(\"o_orderdate\").is_between(var_1, var_2, closed=\"left\"),\n", " nw.col(\"l_commitdate\") < nw.col(\"l_receiptdate\"),\n", - " ).unique(subset=[\"o_orderpriority\", \"l_orderkey\"])\n", + " )\n", + " .unique(subset=[\"o_orderpriority\", \"l_orderkey\"])\n", " .group_by(\"o_orderpriority\")\n", " .agg(nw.len().alias(\"order_count\"))\n", " .sort(by=\"o_orderpriority\")\n", @@ -130,15 +132,11 @@ "outputs": [], "source": [ "from typing import Any\n", - "from datetime import datetime\n", + "\n", "import ibis\n", "\n", - "def q4_ibis(\n", - " lineitem: Any,\n", - " orders: Any,\n", - " *,\n", - " tool: str\n", - ") -> Any:\n", + "\n", + "def q4_ibis(lineitem: Any, orders: Any, *, tool: str) -> Any:\n", " var1 = datetime(1993, 7, 1)\n", " var2 = datetime(1993, 10, 1)\n", "\n", @@ -151,9 +149,9 @@ " .agg(order_count=ibis._.count())\n", " .order_by(\"o_orderpriority\")\n", " )\n", - " if tool == 'pandas':\n", + " if tool == \"pandas\":\n", " return q_final.to_pandas()\n", - " if tool == 'polars':\n", + " if tool == \"polars\":\n", " return q_final.to_polars()\n", " raise ValueError(\"expected pandas or polars\")" ] @@ -175,14 +173,14 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "region = dir_ + 'region.parquet'\n", - "nation = dir_ + 'nation.parquet'\n", - "customer = dir_ + 'customer.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "part = dir_ + 'part.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" + "region = dir_ + \"region.parquet\"\n", + "nation = dir_ + \"nation.parquet\"\n", + "customer = dir_ + \"customer.parquet\"\n", + "lineitem = dir_ + \"lineitem.parquet\"\n", + "orders = dir_ + \"orders.parquet\"\n", + "supplier = dir_ + \"supplier.parquet\"\n", + "part = dir_ + \"part.parquet\"\n", + "partsupp = dir_ + \"partsupp.parquet\"" ] }, { @@ -201,18 +199,20 @@ }, "outputs": [], "source": [ - "import ibis\n", - "\n", "con_pd = ibis.pandas.connect()\n", "con_pl = ibis.polars.connect()\n", "\n", "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: 
pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'pandas[pyarrow][ibis]': lambda x: con_pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", - " 'polars[lazy][ibis]': lambda x: con_pl.read_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"pandas[pyarrow][ibis]\": lambda x: con_pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", + " \"polars[lazy][ibis]\": lambda x: con_pl.read_parquet(x),\n", "}" ] }, @@ -241,7 +241,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'polars[lazy][ibis]'\n", + "tool = \"polars[lazy][ibis]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q4_ibis(fn(lineitem), fn(orders), tool='polars')\n", "results[tool] = timings.all_runs" @@ -262,10 +262,10 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q4_pandas_native(fn(lineitem), fn(orders))\n", - "results[tool+'[native]'] = timings.all_runs" + "results[tool + \"[native]\"] = timings.all_runs" ] }, { @@ -301,7 +301,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q4(fn(lineitem), fn(orders))\n", "results[tool] = timings.all_runs" @@ -340,7 +340,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q4(fn(lineitem), fn(orders))\n", "results[tool] = timings.all_runs" @@ -379,7 +379,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q4(fn(lineitem), fn(orders))\n", "results[tool] = timings.all_runs" @@ -418,7 +418,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q4(fn(lineitem), fn(orders)).collect()\n", "results[tool] = timings.all_runs" @@ -440,8 +440,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q5/execute.ipynb b/tpch/notebooks/q5/execute.ipynb index 5f6df9bbc..da0cae78b 100755 --- a/tpch/notebooks/q5/execute.ipynb +++ b/tpch/notebooks/q5/execute.ipynb @@ -49,8 +49,9 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Any\n", "from datetime import date\n", + "from typing import Any\n", + "\n", "\n", "def q5_pandas_native(\n", " region_ds: Any,\n", @@ -79,9 +80,8 @@ " jn5[\"revenue\"] = jn5.l_extendedprice * (1.0 - jn5.l_discount)\n", "\n", " gb = jn5.groupby(\"n_name\", as_index=False)[\"revenue\"].sum()\n", - " result_df = gb.sort_values(\"revenue\", ascending=False)\n", "\n", - " return result_df # type: ignore[no-any-return]" + " return gb.sort_values(\"revenue\", ascending=False) # type: ignore[no-any-return]" ] }, { @@ -91,10 +91,12 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Any\n", "from datetime import datetime\n", + "from typing 
import Any\n", + "\n", "import narwhals as nw\n", "\n", + "\n", "def q5(\n", " region_ds_raw: Any,\n", " nation_ds_raw: Any,\n", @@ -126,7 +128,7 @@ " )\n", " .filter(\n", " nw.col(\"r_name\") == var_1,\n", - " nw.col(\"o_orderdate\").is_between(var_2, var_3, closed=\"left\")\n", + " nw.col(\"o_orderdate\").is_between(var_2, var_3, closed=\"left\"),\n", " )\n", " .with_columns(\n", " (nw.col(\"l_extendedprice\") * (1 - nw.col(\"l_discount\"))).alias(\"revenue\")\n", @@ -147,10 +149,10 @@ "outputs": [], "source": [ "from typing import Any\n", - "from datetime import datetime\n", - "import narwhals as nw\n", + "\n", "import ibis\n", "\n", + "\n", "def q5_ibis(\n", " region: Any,\n", " nation: Any,\n", @@ -183,9 +185,9 @@ " .order_by(ibis.desc(\"revenue\"))\n", " )\n", "\n", - " if tool == 'pandas':\n", + " if tool == \"pandas\":\n", " return q_final.to_pandas()\n", - " if tool == 'polars':\n", + " if tool == \"polars\":\n", " return q_final.to_polars()\n", " raise ValueError(\"expected pandas or polars\")" ] @@ -207,14 +209,14 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "region = dir_ + 'region.parquet'\n", - "nation = dir_ + 'nation.parquet'\n", - "customer = dir_ + 'customer.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "part = dir_ + 'part.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" + "region = dir_ + \"region.parquet\"\n", + "nation = dir_ + \"nation.parquet\"\n", + "customer = dir_ + \"customer.parquet\"\n", + "lineitem = dir_ + \"lineitem.parquet\"\n", + "orders = dir_ + \"orders.parquet\"\n", + "supplier = dir_ + \"supplier.parquet\"\n", + "part = dir_ + \"part.parquet\"\n", + "partsupp = dir_ + \"partsupp.parquet\"" ] }, { @@ -233,18 +235,20 @@ }, "outputs": [], "source": [ - "import ibis\n", - "\n", "con_pd = ibis.pandas.connect()\n", "con_pl = ibis.polars.connect()\n", "\n", "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'pandas[pyarrow][ibis]': lambda x: con_pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", - " 'polars[lazy][ibis]': lambda x: con_pl.read_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"pandas[pyarrow][ibis]\": lambda x: con_pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", + " \"polars[lazy][ibis]\": lambda x: con_pl.read_parquet(x),\n", "}" ] }, @@ -273,7 +277,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'polars[lazy][ibis]'\n", + "tool = \"polars[lazy][ibis]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q5_ibis(fn(region), fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier), tool='polars')\n", "results[tool] = timings.all_runs" @@ -294,10 +298,10 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q5_pandas_native(fn(region), fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", - 
"results[tool+'[native]'] = timings.all_runs" + "results[tool + \"[native]\"] = timings.all_runs" ] }, { @@ -333,7 +337,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q5(fn(region), fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -372,7 +376,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q5(fn(region), fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -411,7 +415,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q5(fn(region), fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -450,7 +454,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q5(fn(region), fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier)).collect()\n", "results[tool] = timings.all_runs" @@ -472,8 +476,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q6/execute.ipynb b/tpch/notebooks/q6/execute.ipynb index b101aa98d..5abcb65f0 100755 --- a/tpch/notebooks/q6/execute.ipynb +++ b/tpch/notebooks/q6/execute.ipynb @@ -50,6 +50,7 @@ "source": [ "from datetime import date\n", "\n", + "\n", "def q6_pandas_native(line_item_ds):\n", " var1 = date(1994, 1, 1)\n", " var2 = date(1995, 1, 1)\n", @@ -66,9 +67,8 @@ " ]\n", "\n", " result_value = (flineitem[\"l_extendedprice\"] * flineitem[\"l_discount\"]).sum()\n", - " result_df = pd.DataFrame({\"revenue\": [result_value]})\n", "\n", - " return result_df" + " return pd.DataFrame({\"revenue\": [result_value]})" ] }, { @@ -87,10 +87,11 @@ }, "outputs": [], "source": [ - "from typing import Any\n", "from datetime import datetime\n", + "\n", "import narwhals as nw\n", "\n", + "\n", "def q6(line_item_raw) -> None:\n", " var_1 = datetime(1994, 1, 1)\n", " var_2 = datetime(1995, 1, 1)\n", @@ -103,12 +104,11 @@ " nw.col(\"l_shipdate\").is_between(var_1, var_2, closed=\"left\"),\n", " nw.col(\"l_discount\").is_between(0.05, 0.07),\n", " nw.col(\"l_quantity\") < var_3,\n", - " ).with_columns(\n", - " (nw.col(\"l_extendedprice\") * nw.col(\"l_discount\")).alias(\"revenue\")\n", " )\n", + " .with_columns((nw.col(\"l_extendedprice\") * nw.col(\"l_discount\")).alias(\"revenue\"))\n", " .select(nw.sum(\"revenue\"))\n", " )\n", - " return nw.to_native(result)\n" + " return nw.to_native(result)" ] }, { @@ -118,10 +118,6 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Any\n", - "from datetime import datetime\n", - "import narwhals as nw\n", - "\n", "def q6_ibis(lineitem, *, tool: str) -> None:\n", " var1 = datetime(1994, 1, 1)\n", " var2 = datetime(1995, 1, 1)\n", @@ -138,12 +134,12 @@ " .mutate(revenue=ibis._[\"l_extendedprice\"] * (ibis._[\"l_discount\"]))\n", " .agg(revenue=ibis._[\"revenue\"].sum())\n", " )\n", - " \n", - " if tool == 'pandas':\n", + "\n", + " if tool == \"pandas\":\n", " return q_final.to_pandas()\n", - " if tool == 'polars':\n", + " if tool == \"polars\":\n", " return q_final.to_polars()\n", - " raise 
ValueError(\"expected pandas or polars\")\n" + " raise ValueError(\"expected pandas or polars\")" ] }, { @@ -163,14 +159,14 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "region = dir_ + 'region.parquet'\n", - "nation = dir_ + 'nation.parquet'\n", - "customer = dir_ + 'customer.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "part = dir_ + 'part.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" + "region = dir_ + \"region.parquet\"\n", + "nation = dir_ + \"nation.parquet\"\n", + "customer = dir_ + \"customer.parquet\"\n", + "lineitem = dir_ + \"lineitem.parquet\"\n", + "orders = dir_ + \"orders.parquet\"\n", + "supplier = dir_ + \"supplier.parquet\"\n", + "part = dir_ + \"part.parquet\"\n", + "partsupp = dir_ + \"partsupp.parquet\"" ] }, { @@ -195,12 +191,16 @@ "con_pl = ibis.polars.connect()\n", "\n", "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'pandas[pyarrow][ibis]': lambda x: con_pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", - " 'polars[lazy][ibis]': lambda x: con_pl.read_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"pandas[pyarrow][ibis]\": lambda x: con_pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", + " \"polars[lazy][ibis]\": lambda x: con_pl.read_parquet(x),\n", "}" ] }, @@ -229,7 +229,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pandas[pyarrow][ibis]'\n", + "tool = \"pandas[pyarrow][ibis]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q6_ibis(fn(lineitem), tool='pandas')\n", "results[tool] = timings.all_runs" @@ -250,7 +250,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'polars[lazy][ibis]'\n", + "tool = \"polars[lazy][ibis]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q6_ibis(fn(lineitem), tool='polars')\n", "results[tool] = timings.all_runs" @@ -271,10 +271,10 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q6_pandas_native(fn(lineitem))\n", - "results[tool+'[native]'] = timings.all_runs" + "results[tool + \"[native]\"] = timings.all_runs" ] }, { @@ -310,7 +310,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q6(fn(lineitem))\n", "results[tool] = timings.all_runs" @@ -349,7 +349,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q6(fn(lineitem))\n", "results[tool] = timings.all_runs" @@ -388,7 +388,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q6(fn(lineitem))\n", "results[tool] = timings.all_runs" @@ -427,7 +427,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = 
IO_FUNCS[tool]\n", "timings = %timeit -o -q q6(fn(lineitem)).collect()\n", "results[tool] = timings.all_runs" @@ -449,8 +449,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q7/execute.ipynb b/tpch/notebooks/q7/execute.ipynb index 1213043b0..8711d7505 100755 --- a/tpch/notebooks/q7/execute.ipynb +++ b/tpch/notebooks/q7/execute.ipynb @@ -49,10 +49,13 @@ "metadata": {}, "outputs": [], "source": [ + "from datetime import date\n", + "from datetime import datetime\n", "from typing import Any\n", - "from datetime import datetime, date\n", + "\n", "import narwhals as nw\n", "\n", + "\n", "def q7_pandas_native(\n", " nation_ds,\n", " customer_ds,\n", @@ -96,9 +99,7 @@ " gb = total.groupby([\"supp_nation\", \"cust_nation\", \"l_year\"], as_index=False)\n", " agg = gb.agg(revenue=pd.NamedAgg(column=\"volume\", aggfunc=\"sum\"))\n", "\n", - " result_df = agg.sort_values(by=[\"supp_nation\", \"cust_nation\", \"l_year\"])\n", - "\n", - " return result_df # type: ignore[no-any-return]" + " return agg.sort_values(by=[\"supp_nation\", \"cust_nation\", \"l_year\"]) # type: ignore[no-any-return]" ] }, { @@ -117,10 +118,6 @@ }, "outputs": [], "source": [ - "from typing import Any\n", - "from datetime import datetime\n", - "import narwhals as nw\n", - "\n", "def q7(\n", " nation_ds,\n", " customer_ds,\n", @@ -171,7 +168,7 @@ " .agg(nw.sum(\"volume\").alias(\"revenue\"))\n", " .sort(by=[\"supp_nation\", \"cust_nation\", \"l_year\"])\n", " )\n", - " return nw.to_native(result)\n" + " return nw.to_native(result)" ] }, { @@ -181,18 +178,11 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Any\n", - "from datetime import datetime\n", "import ibis\n", "\n", + "\n", "def q7_ibis(\n", - " nation: Any,\n", - " customer: Any,\n", - " lineitem: Any,\n", - " orders: Any,\n", - " supplier: Any,\n", - " *,\n", - " tool: str\n", + " nation: Any, customer: Any, lineitem: Any, orders: Any, supplier: Any, *, tool: str\n", ") -> None:\n", " var1 = \"FRANCE\"\n", " var2 = \"GERMANY\"\n", @@ -234,9 +224,9 @@ " .order_by(\"supp_nation\", \"cust_nation\", \"l_year\")\n", " )\n", "\n", - " if tool == 'pandas':\n", + " if tool == \"pandas\":\n", " return q_final.to_pandas()\n", - " if tool == 'polars':\n", + " if tool == \"polars\":\n", " return q_final.to_polars()\n", " raise ValueError(\"expected pandas or polars\")" ] @@ -258,14 +248,14 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "region = dir_ + 'region.parquet'\n", - "nation = dir_ + 'nation.parquet'\n", - "customer = dir_ + 'customer.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "part = dir_ + 'part.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" + "region = dir_ + \"region.parquet\"\n", + "nation = dir_ + \"nation.parquet\"\n", + "customer = dir_ + \"customer.parquet\"\n", + "lineitem = dir_ + \"lineitem.parquet\"\n", + "orders = dir_ + \"orders.parquet\"\n", + "supplier = dir_ + \"supplier.parquet\"\n", + "part = dir_ + \"part.parquet\"\n", + "partsupp = dir_ + \"partsupp.parquet\"" ] }, { @@ -284,18 +274,20 @@ }, "outputs": [], "source": [ - "import ibis\n", - "\n", "con_pd = ibis.pandas.connect()\n", "con_pl = ibis.polars.connect()\n", "\n", "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, 
engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'pandas[pyarrow][ibis]': lambda x: con_pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", - " 'polars[lazy][ibis]': lambda x: con_pl.read_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"pandas[pyarrow][ibis]\": lambda x: con_pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", + " \"polars[lazy][ibis]\": lambda x: con_pl.read_parquet(x),\n", "}" ] }, @@ -324,7 +316,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pandas[pyarrow][ibis]'\n", + "tool = \"pandas[pyarrow][ibis]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q7_ibis(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier), tool='pandas')\n", "results[tool] = timings.all_runs" @@ -345,7 +337,7 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'polars[lazy][ibis]'\n", + "tool = \"polars[lazy][ibis]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q7_ibis(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier), tool='polars')\n", "results[tool] = timings.all_runs" @@ -366,10 +358,10 @@ "metadata": {}, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q7_pandas_native(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", - "results[tool+'[native]'] = timings.all_runs" + "results[tool + \"[native]\"] = timings.all_runs" ] }, { @@ -405,7 +397,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -444,7 +436,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -483,7 +475,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -522,7 +514,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier)).collect()\n", "results[tool] = timings.all_runs" @@ -544,8 +536,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/notebooks/q8/execute.ipynb b/tpch/notebooks/q8/execute.ipynb deleted file mode 100755 index b10b87907..000000000 --- a/tpch/notebooks/q8/execute.ipynb +++ /dev/null @@ -1,502 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0", - "metadata": { - "papermill": { 
- "duration": 33.390992, - "end_time": "2024-03-22T17:24:15.601719", - "exception": false, - "start_time": "2024-03-22T17:23:42.210727", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "!pip install -U polars pyarrow " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -U narwhals" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2", - "metadata": { - "papermill": { - "duration": 0.907754, - "end_time": "2024-03-22T17:24:39.053873", - "exception": false, - "start_time": "2024-03-22T17:24:38.146119", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import pandas as pd\n", - "import polars as pl\n", - "\n", - "pd.options.mode.copy_on_write = True\n", - "pd.options.future.infer_string = True" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Any\n", - "from datetime import datetime, date\n", - "import narwhals as nw\n", - "\n", - "def q8_pandas_native(\n", - " nation_ds,\n", - " customer_ds,\n", - " line_item_ds,\n", - " orders_ds,\n", - " supplier_ds,\n", - ") -> None:\n", - " var1 = \"FRANCE\"\n", - " var2 = \"GERMANY\"\n", - " var3 = date(1995, 1, 1)\n", - " var4 = date(1996, 12, 31)\n", - "\n", - " n1 = nation_ds[(nation_ds[\"n_name\"] == var1)]\n", - " n2 = nation_ds[(nation_ds[\"n_name\"] == var2)]\n", - "\n", - " # Part 1\n", - " jn1 = customer_ds.merge(n1, left_on=\"c_nationkey\", right_on=\"n_nationkey\")\n", - " jn2 = jn1.merge(orders_ds, left_on=\"c_custkey\", right_on=\"o_custkey\")\n", - " jn2 = jn2.rename({\"n_name\": \"cust_nation\"}, axis=\"columns\")\n", - " jn3 = jn2.merge(line_item_ds, left_on=\"o_orderkey\", right_on=\"l_orderkey\")\n", - " jn4 = jn3.merge(supplier_ds, left_on=\"l_suppkey\", right_on=\"s_suppkey\")\n", - " jn5 = jn4.merge(n2, left_on=\"s_nationkey\", right_on=\"n_nationkey\")\n", - " df1 = jn5.rename({\"n_name\": \"supp_nation\"}, axis=\"columns\")\n", - "\n", - " # Part 2\n", - " jn1 = customer_ds.merge(n2, left_on=\"c_nationkey\", right_on=\"n_nationkey\")\n", - " jn2 = jn1.merge(orders_ds, left_on=\"c_custkey\", right_on=\"o_custkey\")\n", - " jn2 = jn2.rename({\"n_name\": \"cust_nation\"}, axis=\"columns\")\n", - " jn3 = jn2.merge(line_item_ds, left_on=\"o_orderkey\", right_on=\"l_orderkey\")\n", - " jn4 = jn3.merge(supplier_ds, left_on=\"l_suppkey\", right_on=\"s_suppkey\")\n", - " jn5 = jn4.merge(n1, left_on=\"s_nationkey\", right_on=\"n_nationkey\")\n", - " df2 = jn5.rename({\"n_name\": \"supp_nation\"}, axis=\"columns\")\n", - "\n", - " # Combine\n", - " total = pd.concat([df1, df2])\n", - "\n", - " total = total[(total[\"l_shipdate\"] >= var3) & (total[\"l_shipdate\"] <= var4)]\n", - " total[\"volume\"] = total[\"l_extendedprice\"] * (1.0 - total[\"l_discount\"])\n", - " total[\"l_year\"] = total[\"l_shipdate\"].dt.year\n", - "\n", - " gb = total.groupby([\"supp_nation\", \"cust_nation\", \"l_year\"], as_index=False)\n", - " agg = gb.agg(revenue=pd.NamedAgg(column=\"volume\", aggfunc=\"sum\"))\n", - "\n", - " result_df = agg.sort_values(by=[\"supp_nation\", \"cust_nation\", \"l_year\"])\n", - "\n", - " return result_df # type: ignore[no-any-return]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": { - "papermill": { - "duration": 0.021725, - "end_time": "2024-03-22T17:24:39.080999", - "exception": false, - "start_time": 
"2024-03-22T17:24:39.059274", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from typing import Any\n", - "from datetime import datetime\n", - "import narwhals as nw\n", - "\n", - "def q8(\n", - " nation_ds_raw: Any,\n", - " customer_ds_raw: Any,\n", - " line_item_ds_raw: Any,\n", - " orders_ds_raw: Any,\n", - " supplier_ds_raw: Any,\n", - " part_ds_raw: Any,\n", - ") -> None:\n", - " nation_ds = nw.from_native(nation_ds_raw)\n", - " customer_ds = nw.from_native(customer_ds_raw)\n", - " line_item_ds = nw.from_native(line_item_ds_raw)\n", - " orders_ds = nw.from_native(orders_ds_raw)\n", - " supplier_ds = nw.from_native(supplier_ds_raw)\n", - " part_ds = nw.from_native(part_ds_raw)\n", - "\n", - " n1 = nation_ds.select(\"n_nationkey\", \"n_regionkey\")\n", - " n2 = nation_ds.select(\"n_nationkey\", \"n_name\")\n", - "\n", - " result = (\n", - " part_ds.join(line_item_ds, left_on=\"p_partkey\", right_on=\"l_partkey\")\n", - " .join(supplier_ds, left_on=\"l_suppkey\", right_on=\"s_suppkey\")\n", - " .join(orders_ds, left_on=\"l_orderkey\", right_on=\"o_orderkey\")\n", - " .join(customer_ds, left_on=\"o_custkey\", right_on=\"c_custkey\")\n", - " .join(n1, left_on=\"c_nationkey\", right_on=\"n_nationkey\")\n", - " .join(region_ds, left_on=\"n_regionkey\", right_on=\"r_regionkey\")\n", - " .filter(nw.col(\"r_name\") == \"AMERICA\")\n", - " .join(n2, left_on=\"s_nationkey\", right_on=\"n_nationkey\")\n", - " .filter(\n", - " nw.col(\"o_orderdate\")>= date(1995, 1, 1),\n", - " nw.col('o_orderdate')<=date(1996, 12, 31)\n", - " )\n", - " .filter(nw.col(\"p_type\") == \"ECONOMY ANODIZED STEEL\")\n", - " .select(\n", - " nw.col(\"o_orderdate\").dt.year().alias(\"o_year\"),\n", - " (nw.col(\"l_extendedprice\") * (1 - nw.col(\"l_discount\"))).alias(\"volume\"),\n", - " nw.col(\"n_name\").alias(\"nation\"),\n", - " )\n", - " .with_columns(\n", - " nw.when(nw.col(\"nation\") == \"BRAZIL\")\n", - " .then(nw.col(\"volume\"))\n", - " .otherwise(0)\n", - " .alias(\"_tmp\")\n", - " )\n", - " .group_by(\"o_year\")\n", - " .agg((nw.sum(\"_tmp\") / nw.sum(\"volume\")).round(2).alias(\"mkt_share\"))\n", - " .sort(\"o_year\")\n", - " )\n", - " \n", - " return nw.to_native(result)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5", - "metadata": { - "papermill": { - "duration": 0.013325, - "end_time": "2024-03-22T17:24:39.099766", - "exception": false, - "start_time": "2024-03-22T17:24:39.086441", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "region = dir_ + 'region.parquet'\n", - "nation = dir_ + 'nation.parquet'\n", - "customer = dir_ + 'customer.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "part = dir_ + 'part.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": { - "papermill": { - "duration": 0.014284, - "end_time": "2024-03-22T17:24:39.119737", - "exception": false, - "start_time": "2024-03-22T17:24:39.105453", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: 
pl.scan_parquet(x),\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": {}, - "outputs": [], - "source": [ - "results = {}" - ] - }, - { - "cell_type": "markdown", - "id": "8", - "metadata": {}, - "source": [ - "## pandas, pyarrow dtypes, native" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": {}, - "outputs": [], - "source": [ - "tool = 'pandas[pyarrow]'\n", - "fn = IO_FUNCS[tool]\n", - "timings = %timeit -o -q q7_pandas_native(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", - "results[tool+'[native]'] = timings.all_runs" - ] - }, - { - "cell_type": "markdown", - "id": "10", - "metadata": { - "papermill": { - "duration": 0.005113, - "end_time": "2024-03-22T17:24:39.130472", - "exception": false, - "start_time": "2024-03-22T17:24:39.125359", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## pandas via Narwhals" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "11", - "metadata": { - "papermill": { - "duration": 196.786925, - "end_time": "2024-03-22T17:27:55.922832", - "exception": false, - "start_time": "2024-03-22T17:24:39.135907", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "tool = 'pandas'\n", - "fn = IO_FUNCS[tool]\n", - "timings = %timeit -o -q q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", - "results[tool] = timings.all_runs" - ] - }, - { - "cell_type": "markdown", - "id": "12", - "metadata": { - "papermill": { - "duration": 0.005184, - "end_time": "2024-03-22T17:27:55.933407", - "exception": false, - "start_time": "2024-03-22T17:27:55.928223", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## pandas, pyarrow dtypes, via Narwhals" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13", - "metadata": { - "papermill": { - "duration": 158.748353, - "end_time": "2024-03-22T17:30:34.688289", - "exception": false, - "start_time": "2024-03-22T17:27:55.939936", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "tool = 'pandas[pyarrow]'\n", - "fn = IO_FUNCS[tool]\n", - "timings = %timeit -o -q q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", - "results[tool] = timings.all_runs" - ] - }, - { - "cell_type": "markdown", - "id": "14", - "metadata": { - "papermill": { - "duration": 0.005773, - "end_time": "2024-03-22T17:30:34.700300", - "exception": false, - "start_time": "2024-03-22T17:30:34.694527", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## Polars read_parquet" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15", - "metadata": { - "papermill": { - "duration": 37.821116, - "end_time": "2024-03-22T17:31:12.527466", - "exception": false, - "start_time": "2024-03-22T17:30:34.706350", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "tool = 'polars[eager]'\n", - "fn = IO_FUNCS[tool]\n", - "timings = %timeit -o -q q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))\n", - "results[tool] = timings.all_runs" - ] - }, - { - "cell_type": "markdown", - "id": "16", - "metadata": { - "papermill": { - "duration": 0.005515, - "end_time": "2024-03-22T17:31:12.539068", - "exception": false, - "start_time": "2024-03-22T17:31:12.533553", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## Polars scan_parquet" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17", - "metadata": { - 
"papermill": { - "duration": 4.800698, - "end_time": "2024-03-22T17:31:17.346813", - "exception": false, - "start_time": "2024-03-22T17:31:12.546115", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "tool = 'polars[lazy]'\n", - "fn = IO_FUNCS[tool]\n", - "timings = %timeit -o -q q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier)).collect()\n", - "results[tool] = timings.all_runs" - ] - }, - { - "cell_type": "markdown", - "id": "18", - "metadata": {}, - "source": [ - "## Save" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" - ] - } - ], - "metadata": { - "kaggle": { - "accelerator": "none", - "dataSources": [ - { - "sourceId": 167796716, - "sourceType": "kernelVersion" - }, - { - "sourceId": 167796934, - "sourceType": "kernelVersion" - }, - { - "sourceId": 167796952, - "sourceType": "kernelVersion" - }, - { - "sourceId": 167796969, - "sourceType": "kernelVersion" - } - ], - "isGpuEnabled": false, - "isInternetEnabled": true, - "language": "python", - "sourceType": "notebook" - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - }, - "papermill": { - "default_parameters": {}, - "duration": 458.423327, - "end_time": "2024-03-22T17:31:18.077306", - "environment_variables": {}, - "exception": null, - "input_path": "__notebook__.ipynb", - "output_path": "__notebook__.ipynb", - "parameters": {}, - "start_time": "2024-03-22T17:23:39.653979", - "version": "2.5.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tpch/notebooks/q8/kernel-metadata.json b/tpch/notebooks/q8/kernel-metadata.json deleted file mode 100644 index 1c67f0d53..000000000 --- a/tpch/notebooks/q8/kernel-metadata.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "id": "marcogorelli/narwhals-tpch-q8-s2", - "title": "Narwhals TPCH Q8 S2", - "code_file": "execute.ipynb", - "language": "python", - "kernel_type": "notebook", - "is_private": "false", - "enable_gpu": "false", - "enable_tpu": "false", - "enable_internet": "true", - "dataset_sources": [], - "competition_sources": [], - "kernel_sources": ["marcogorelli/tpc-h-data-parquet-s-2"], - "model_sources": [] -} \ No newline at end of file diff --git a/tpch/notebooks/q9/execute.ipynb b/tpch/notebooks/q9/execute.ipynb index 86417e180..802799a01 100644 --- a/tpch/notebooks/q9/execute.ipynb +++ b/tpch/notebooks/q9/execute.ipynb @@ -15,7 +15,7 @@ }, "outputs": [], "source": [ - "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals " + "!pip uninstall apache-beam -y && pip install -U pandas polars pyarrow narwhals" ] }, { @@ -56,8 +56,10 @@ "outputs": [], "source": [ "from typing import Any\n", + "\n", "import narwhals as nw\n", "\n", + "\n", "def q9(\n", " part_ds_raw: Any,\n", " partsupp_ds_raw: Any,\n", @@ -66,7 +68,6 @@ " orders_ds_raw: Any,\n", " supplier_ds_raw: Any,\n", ") -> Any:\n", - "\n", " part_ds = nw.from_native(part_ds_raw)\n", " nation_ds = nw.from_native(nation_ds_raw)\n", " partsupp_ds = nw.from_native(partsupp_ds_raw)\n", @@ -91,7 +92,7 @@ " (\n", " nw.col(\"l_extendedprice\") * (1 - nw.col(\"l_discount\"))\n", " - 
nw.col(\"ps_supplycost\") * nw.col(\"l_quantity\")\n", - " ).alias(\"amount\")\n", + " ).alias(\"amount\"),\n", " )\n", " .group_by(\"nation\", \"o_year\")\n", " .agg(nw.sum(\"amount\").alias(\"sum_profit\"))\n", @@ -117,12 +118,12 @@ "outputs": [], "source": [ "dir_ = \"/kaggle/input/tpc-h-data-parquet-s-2/\"\n", - "nation = dir_ + 'nation.parquet'\n", - "lineitem = dir_ + 'lineitem.parquet'\n", - "orders = dir_ + 'orders.parquet'\n", - "supplier = dir_ + 'supplier.parquet'\n", - "part = dir_ + 'part.parquet'\n", - "partsupp = dir_ + 'partsupp.parquet'" + "nation = dir_ + \"nation.parquet\"\n", + "lineitem = dir_ + \"lineitem.parquet\"\n", + "orders = dir_ + \"orders.parquet\"\n", + "supplier = dir_ + \"supplier.parquet\"\n", + "part = dir_ + \"part.parquet\"\n", + "partsupp = dir_ + \"partsupp.parquet\"" ] }, { @@ -141,10 +142,12 @@ "outputs": [], "source": [ "IO_FUNCS = {\n", - " 'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),\n", - " 'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),\n", - " 'polars[eager]': lambda x: pl.read_parquet(x),\n", - " 'polars[lazy]': lambda x: pl.scan_parquet(x),\n", + " \"pandas\": lambda x: pd.read_parquet(x, engine=\"pyarrow\"),\n", + " \"pandas[pyarrow]\": lambda x: pd.read_parquet(\n", + " x, engine=\"pyarrow\", dtype_backend=\"pyarrow\"\n", + " ),\n", + " \"polars[eager]\": lambda x: pl.read_parquet(x),\n", + " \"polars[lazy]\": lambda x: pl.scan_parquet(x),\n", "}" ] }, @@ -188,7 +191,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas'\n", + "tool = \"pandas\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q9(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(orders), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -225,7 +228,7 @@ }, "outputs": [], "source": [ - "tool = 'pandas[pyarrow]'\n", + "tool = \"pandas[pyarrow]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q9(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(orders), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -262,7 +265,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[eager]'\n", + "tool = \"polars[eager]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q9(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(orders), fn(supplier))\n", "results[tool] = timings.all_runs" @@ -299,7 +302,7 @@ }, "outputs": [], "source": [ - "tool = 'polars[lazy]'\n", + "tool = \"polars[lazy]\"\n", "fn = IO_FUNCS[tool]\n", "timings = %timeit -o -q q9(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(orders), fn(supplier)).collect()\n", "results[tool] = timings.all_runs" @@ -319,8 +322,9 @@ "outputs": [], "source": [ "import json\n", - "with open('results.json', 'w') as fd:\n", - " json.dump(results, fd)\n" + "\n", + "with open(\"results.json\", \"w\") as fd:\n", + " json.dump(results, fd)" ] } ], diff --git a/tpch/queries/__init__.py b/tpch/queries/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tpch/queries/q1.py b/tpch/queries/q1.py new file mode 100644 index 000000000..de6157702 --- /dev/null +++ b/tpch/queries/q1.py @@ -0,0 +1,32 @@ +from datetime import datetime + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query(lineitem: FrameT) -> FrameT: + var_1 = datetime(1998, 9, 2) + return ( + lineitem.filter(nw.col("l_shipdate") <= var_1) + .with_columns( + disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")), + charge=( + nw.col("l_extendedprice") + * (1.0 - nw.col("l_discount")) + * (1.0 + nw.col("l_tax")) + ), + ) + 
.group_by("l_returnflag", "l_linestatus") + .agg( + nw.sum("l_quantity").alias("sum_qty"), + nw.sum("l_extendedprice").alias("sum_base_price"), + nw.sum("disc_price").alias("sum_disc_price"), + nw.sum("charge").alias("sum_charge"), + nw.mean("l_quantity").alias("avg_qty"), + nw.mean("l_extendedprice").alias("avg_price"), + nw.mean("l_discount").alias("avg_disc"), + nw.len().alias("count_order"), + ) + .sort("l_returnflag", "l_linestatus") + ) diff --git a/tpch/queries/q10.py b/tpch/queries/q10.py new file mode 100644 index 000000000..486e4ba82 --- /dev/null +++ b/tpch/queries/q10.py @@ -0,0 +1,48 @@ +from datetime import datetime + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + customer_ds: FrameT, + nation_ds: FrameT, + lineitem_ds: FrameT, + orders_ds: FrameT, +) -> FrameT: + var1 = datetime(1993, 10, 1) + var2 = datetime(1994, 1, 1) + + return ( + customer_ds.join(orders_ds, left_on="c_custkey", right_on="o_custkey") + .join(lineitem_ds, left_on="o_orderkey", right_on="l_orderkey") + .join(nation_ds, left_on="c_nationkey", right_on="n_nationkey") + .filter(nw.col("o_orderdate").is_between(var1, var2, closed="left")) + .filter(nw.col("l_returnflag") == "R") + .with_columns( + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("revenue") + ) + .group_by( + "c_custkey", + "c_name", + "c_acctbal", + "c_phone", + "n_name", + "c_address", + "c_comment", + ) + .agg(nw.sum("revenue")) + .select( + "c_custkey", + "c_name", + "revenue", + "c_acctbal", + "n_name", + "c_address", + "c_phone", + "c_comment", + ) + .sort(by="revenue", descending=True) + .head(20) + ) diff --git a/tpch/queries/q11.py b/tpch/queries/q11.py new file mode 100644 index 000000000..d5b48b359 --- /dev/null +++ b/tpch/queries/q11.py @@ -0,0 +1,32 @@ +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + nation_ds: FrameT, + partsupp_ds: FrameT, + supplier_ds: FrameT, +) -> FrameT: + var1 = "GERMANY" + var2 = 0.0001 + + q1 = ( + partsupp_ds.join(supplier_ds, left_on="ps_suppkey", right_on="s_suppkey") + .join(nation_ds, left_on="s_nationkey", right_on="n_nationkey") + .filter(nw.col("n_name") == var1) + ) + q2 = q1.select( + (nw.col("ps_supplycost") * nw.col("ps_availqty")).sum().round(2).alias("tmp") + * var2 + ) + + return ( + q1.with_columns((nw.col("ps_supplycost") * nw.col("ps_availqty")).alias("value")) + .group_by("ps_partkey") + .agg(nw.sum("value")) + .join(q2, how="cross") + .filter(nw.col("value") > nw.col("tmp")) + .select("ps_partkey", "value") + .sort("value", descending=True) + ) diff --git a/tpch/queries/q12.py b/tpch/queries/q12.py new file mode 100644 index 000000000..ced775830 --- /dev/null +++ b/tpch/queries/q12.py @@ -0,0 +1,33 @@ +from datetime import datetime + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query(line_item_ds: FrameT, orders_ds: FrameT) -> FrameT: + var1 = "MAIL" + var2 = "SHIP" + var3 = datetime(1994, 1, 1) + var4 = datetime(1995, 1, 1) + + return ( + orders_ds.join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") + .filter(nw.col("l_shipmode").is_in([var1, var2])) + .filter(nw.col("l_commitdate") < nw.col("l_receiptdate")) + .filter(nw.col("l_shipdate") < nw.col("l_commitdate")) + .filter(nw.col("l_receiptdate").is_between(var3, var4, closed="left")) + .with_columns( + nw.when(nw.col("o_orderpriority").is_in(["1-URGENT", "2-HIGH"])) + .then(1) + .otherwise(0) + .alias("high_line_count"), + nw.when(~nw.col("o_orderpriority").is_in(["1-URGENT", 
"2-HIGH"])) + .then(1) + .otherwise(0) + .alias("low_line_count"), + ) + .group_by("l_shipmode") + .agg(nw.col("high_line_count").sum(), nw.col("low_line_count").sum()) + .sort("l_shipmode") + ) diff --git a/tpch/queries/q13.py b/tpch/queries/q13.py new file mode 100644 index 000000000..adf57e5a2 --- /dev/null +++ b/tpch/queries/q13.py @@ -0,0 +1,19 @@ +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT: + var1 = "special" + var2 = "requests" + + orders = orders_ds.filter(~nw.col("o_comment").str.contains(f"{var1}.*{var2}")) + return ( + customer_ds.join(orders, left_on="c_custkey", right_on="o_custkey", how="left") + .group_by("c_custkey") + .agg(nw.col("o_orderkey").count().alias("c_count")) + .group_by("c_count") + .agg(nw.len()) + .select(nw.col("c_count"), nw.col("len").alias("custdist")) + .sort(by=["custdist", "c_count"], descending=[True, True]) + ) diff --git a/tpch/queries/q14.py b/tpch/queries/q14.py new file mode 100644 index 000000000..f1ec6cbe3 --- /dev/null +++ b/tpch/queries/q14.py @@ -0,0 +1,27 @@ +from datetime import datetime + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query(line_item_ds: FrameT, part_ds: FrameT) -> FrameT: + var1 = datetime(1995, 9, 1) + var2 = datetime(1995, 10, 1) + + return ( + line_item_ds.join(part_ds, left_on="l_partkey", right_on="p_partkey") + .filter(nw.col("l_shipdate").is_between(var1, var2, closed="left")) + .select( + ( + 100.00 + * nw.when(nw.col("p_type").str.contains("PROMO*")) + .then(nw.col("l_extendedprice") * (1 - nw.col("l_discount"))) + .otherwise(0) + .sum() + / (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).sum() + ) + .round(2) + .alias("promo_revenue") + ) + ) diff --git a/tpch/queries/q15.py b/tpch/queries/q15.py new file mode 100644 index 000000000..1ebae57d6 --- /dev/null +++ b/tpch/queries/q15.py @@ -0,0 +1,33 @@ +from datetime import datetime + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + lineitem_ds: FrameT, + supplier_ds: FrameT, +) -> FrameT: + var1 = datetime(1996, 1, 1) + var2 = datetime(1996, 4, 1) + + revenue = ( + lineitem_ds.filter(nw.col("l_shipdate").is_between(var1, var2, closed="left")) + .with_columns( + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias( + "total_revenue" + ) + ) + .group_by("l_suppkey") + .agg(nw.sum("total_revenue")) + .select(nw.col("l_suppkey").alias("supplier_no"), nw.col("total_revenue")) + ) + + return ( + supplier_ds.join(revenue, left_on="s_suppkey", right_on="supplier_no") + .filter(nw.col("total_revenue") == nw.col("total_revenue").max()) + .with_columns(nw.col("total_revenue").round(2)) + .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") + .sort("s_suppkey") + ) diff --git a/tpch/queries/q16.py b/tpch/queries/q16.py new file mode 100644 index 000000000..d84b9aab5 --- /dev/null +++ b/tpch/queries/q16.py @@ -0,0 +1,26 @@ +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query(part_ds: FrameT, partsupp_ds: FrameT, supplier_ds: FrameT) -> FrameT: + var1 = "Brand#45" + + supplier = supplier_ds.filter( + nw.col("s_comment").str.contains(".*Customer.*Complaints.*") + ).select(nw.col("s_suppkey"), nw.col("s_suppkey").alias("ps_suppkey")) + + return ( + part_ds.join(partsupp_ds, left_on="p_partkey", right_on="ps_partkey") + .filter(nw.col("p_brand") != var1) + .filter(~nw.col("p_type").str.contains("MEDIUM POLISHED*")) + 
.filter(nw.col("p_size").is_in([49, 14, 23, 45, 19, 3, 36, 9])) + .join(supplier, left_on="ps_suppkey", right_on="s_suppkey", how="left") + .filter(nw.col("ps_suppkey_right").is_null()) + .group_by("p_brand", "p_type", "p_size") + .agg(nw.col("ps_suppkey").n_unique().alias("supplier_cnt")) + .sort( + by=["supplier_cnt", "p_brand", "p_type", "p_size"], + descending=[True, False, False, False], + ) + ) diff --git a/tpch/queries/q17.py b/tpch/queries/q17.py new file mode 100644 index 000000000..976f476f0 --- /dev/null +++ b/tpch/queries/q17.py @@ -0,0 +1,24 @@ +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query(lineitem_ds: FrameT, part_ds: FrameT) -> FrameT: + var1 = "Brand#23" + var2 = "MED BOX" + + query1 = ( + part_ds.filter(nw.col("p_brand") == var1) + .filter(nw.col("p_container") == var2) + .join(lineitem_ds, how="left", left_on="p_partkey", right_on="l_partkey") + ) + + return ( + query1.with_columns(l_quantity_times_point_2=nw.col("l_quantity") * 0.2) + .group_by("p_partkey") + .agg(nw.col("l_quantity_times_point_2").mean().alias("avg_quantity")) + .select(nw.col("p_partkey").alias("key"), nw.col("avg_quantity")) + .join(query1, left_on="key", right_on="p_partkey") + .filter(nw.col("l_quantity") < nw.col("avg_quantity")) + .select((nw.col("l_extendedprice").sum() / 7.0).round(2).alias("avg_yearly")) + ) diff --git a/tpch/queries/q18.py b/tpch/queries/q18.py new file mode 100644 index 000000000..d3d183176 --- /dev/null +++ b/tpch/queries/q18.py @@ -0,0 +1,31 @@ +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query(customer_ds: FrameT, lineitem_ds: FrameT, orders_ds: FrameT) -> FrameT: + var1 = 300 + + query1 = ( + lineitem_ds.group_by("l_orderkey") + .agg(nw.col("l_quantity").sum().alias("sum_quantity")) + .filter(nw.col("sum_quantity") > var1) + ) + + return ( + orders_ds.join(query1, left_on="o_orderkey", right_on="l_orderkey", how="semi") + .join(lineitem_ds, left_on="o_orderkey", right_on="l_orderkey") + .join(customer_ds, left_on="o_custkey", right_on="c_custkey") + .group_by("c_name", "o_custkey", "o_orderkey", "o_orderdate", "o_totalprice") + .agg(nw.col("l_quantity").sum().alias("col6")) + .select( + nw.col("c_name"), + nw.col("o_custkey").alias("c_custkey"), + nw.col("o_orderkey"), + nw.col("o_orderdate").alias("o_orderdat"), + nw.col("o_totalprice"), + nw.col("col6"), + ) + .sort(by=["o_totalprice", "o_orderdat"], descending=[True, False]) + .head(100) + ) diff --git a/tpch/queries/q19.py b/tpch/queries/q19.py new file mode 100644 index 000000000..bcab36e9a --- /dev/null +++ b/tpch/queries/q19.py @@ -0,0 +1,39 @@ +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query(lineitem_ds: FrameT, part_ds: FrameT) -> FrameT: + return ( + part_ds.join(lineitem_ds, left_on="p_partkey", right_on="l_partkey") + .filter(nw.col("l_shipmode").is_in(["AIR", "AIR REG"])) + .filter(nw.col("l_shipinstruct") == "DELIVER IN PERSON") + .filter( + ( + (nw.col("p_brand") == "Brand#12") + & nw.col("p_container").is_in(["SM CASE", "SM BOX", "SM PACK", "SM PKG"]) + & (nw.col("l_quantity").is_between(1, 11)) + & (nw.col("p_size").is_between(1, 5)) + ) + | ( + (nw.col("p_brand") == "Brand#23") + & nw.col("p_container").is_in( + ["MED BAG", "MED BOX", "MED PKG", "MED PACK"] + ) + & (nw.col("l_quantity").is_between(10, 20)) + & (nw.col("p_size").is_between(1, 10)) + ) + | ( + (nw.col("p_brand") == "Brand#34") + & nw.col("p_container").is_in(["LG CASE", "LG BOX", "LG PACK", "LG PKG"]) + & 
(nw.col("l_quantity").is_between(20, 30)) + & (nw.col("p_size").is_between(1, 15)) + ) + ) + .select( + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))) + .sum() + .round(2) + .alias("revenue") + ) + ) diff --git a/tpch/queries/q2.py b/tpch/queries/q2.py new file mode 100644 index 000000000..0e9e90d09 --- /dev/null +++ b/tpch/queries/q2.py @@ -0,0 +1,54 @@ +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + region_ds: FrameT, + nation_ds: FrameT, + supplier_ds: FrameT, + part_ds: FrameT, + part_supp_ds: FrameT, +) -> FrameT: + var_1 = 15 + var_2 = "BRASS" + var_3 = "EUROPE" + + result_q2 = ( + part_ds.join(part_supp_ds, left_on="p_partkey", right_on="ps_partkey") + .join(supplier_ds, left_on="ps_suppkey", right_on="s_suppkey") + .join(nation_ds, left_on="s_nationkey", right_on="n_nationkey") + .join(region_ds, left_on="n_regionkey", right_on="r_regionkey") + .filter( + nw.col("p_size") == var_1, + nw.col("p_type").str.ends_with(var_2), + nw.col("r_name") == var_3, + ) + ) + + final_cols = [ + "s_acctbal", + "s_name", + "n_name", + "p_partkey", + "p_mfgr", + "s_address", + "s_phone", + "s_comment", + ] + + return ( + result_q2.group_by("p_partkey") + .agg(nw.col("ps_supplycost").min().alias("ps_supplycost")) + .join( + result_q2, + left_on=["p_partkey", "ps_supplycost"], + right_on=["p_partkey", "ps_supplycost"], + ) + .select(final_cols) + .sort( + ["s_acctbal", "n_name", "s_name", "p_partkey"], + descending=[True, False, False, False], + ) + .head(100) + ) diff --git a/tpch/queries/q20.py b/tpch/queries/q20.py new file mode 100644 index 000000000..d9014f7b8 --- /dev/null +++ b/tpch/queries/q20.py @@ -0,0 +1,43 @@ +from datetime import datetime + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + part_ds: FrameT, + partsupp_ds: FrameT, + nation_ds: FrameT, + lineitem_ds: FrameT, + supplier_ds: FrameT, +) -> FrameT: + var1 = datetime(1994, 1, 1) + var2 = datetime(1995, 1, 1) + var3 = "CANADA" + var4 = "forest" + + query1 = ( + lineitem_ds.filter(nw.col("l_shipdate").is_between(var1, var2, closed="left")) + .group_by("l_partkey", "l_suppkey") + .agg((nw.col("l_quantity").sum()).alias("sum_quantity")) + .with_columns(sum_quantity=nw.col("sum_quantity") * 0.5) + ) + query2 = nation_ds.filter(nw.col("n_name") == var3) + query3 = supplier_ds.join(query2, left_on="s_nationkey", right_on="n_nationkey") + + return ( + part_ds.filter(nw.col("p_name").str.starts_with(var4)) + .select(nw.col("p_partkey").unique()) + .join(partsupp_ds, left_on="p_partkey", right_on="ps_partkey") + .join( + query1, + left_on=["ps_suppkey", "p_partkey"], + right_on=["l_suppkey", "l_partkey"], + ) + .filter(nw.col("ps_availqty") > nw.col("sum_quantity")) + .select(nw.col("ps_suppkey").unique()) + .join(query3, left_on="ps_suppkey", right_on="s_suppkey") + .select("s_name", "s_address") + .sort("s_name") + ) diff --git a/tpch/queries/q21.py b/tpch/queries/q21.py new file mode 100644 index 000000000..d10ff394f --- /dev/null +++ b/tpch/queries/q21.py @@ -0,0 +1,43 @@ +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + lineitem: FrameT, + nation: FrameT, + orders: FrameT, + supplier: FrameT, +) -> FrameT: + var1 = "SAUDI ARABIA" + + q1 = ( + lineitem.group_by("l_orderkey") + .agg(nw.len().alias("n_supp_by_order")) + .filter(nw.col("n_supp_by_order") > 1) + .join( + lineitem.filter(nw.col("l_receiptdate") > nw.col("l_commitdate")), + left_on="l_orderkey", + right_on="l_orderkey", + ) + ) + + 
return ( + q1.group_by("l_orderkey") + .agg(nw.len().alias("n_supp_by_order")) + .join( + q1, + left_on="l_orderkey", + right_on="l_orderkey", + ) + .join(supplier, left_on="l_suppkey", right_on="s_suppkey") + .join(nation, left_on="s_nationkey", right_on="n_nationkey") + .join(orders, left_on="l_orderkey", right_on="o_orderkey") + .filter(nw.col("n_supp_by_order") == 1) + .filter(nw.col("n_name") == var1) + .filter(nw.col("o_orderstatus") == "F") + .group_by("s_name") + .agg(nw.len().alias("numwait")) + .sort(by=["numwait", "s_name"], descending=[True, False]) + .head(100) + ) diff --git a/tpch/queries/q22.py b/tpch/queries/q22.py new file mode 100644 index 000000000..4738c6fd3 --- /dev/null +++ b/tpch/queries/q22.py @@ -0,0 +1,32 @@ +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT: + q1 = ( + customer_ds.with_columns(nw.col("c_phone").str.slice(0, 2).alias("cntrycode")) + .filter(nw.col("cntrycode").str.contains("13|31|23|29|30|18|17")) + .select("c_acctbal", "c_custkey", "cntrycode") + ) + + q2 = q1.filter(nw.col("c_acctbal") > 0.0).select( + nw.col("c_acctbal").mean().alias("avg_acctbal") + ) + + q3 = orders_ds.select(nw.col("o_custkey").unique()).with_columns( + nw.col("o_custkey").alias("c_custkey") + ) + + return ( + q1.join(q3, left_on="c_custkey", right_on="c_custkey", how="left") + .filter(nw.col("o_custkey").is_null()) + .join(q2, how="cross") + .filter(nw.col("c_acctbal") > nw.col("avg_acctbal")) + .group_by("cntrycode") + .agg( + nw.col("c_acctbal").count().alias("numcust"), + nw.col("c_acctbal").sum().alias("totacctbal"), + ) + .sort("cntrycode") + ) diff --git a/tpch/queries/q3.py b/tpch/queries/q3.py new file mode 100644 index 000000000..04679bccb --- /dev/null +++ b/tpch/queries/q3.py @@ -0,0 +1,39 @@ +from datetime import datetime + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + customer_ds: FrameT, + line_item_ds: FrameT, + orders_ds: FrameT, +) -> FrameT: + var_1 = var_2 = datetime(1995, 3, 15) + var_3 = "BUILDING" + + return ( + customer_ds.filter(nw.col("c_mktsegment") == var_3) + .join(orders_ds, left_on="c_custkey", right_on="o_custkey") + .join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") + .filter( + nw.col("o_orderdate") < var_2, + nw.col("l_shipdate") > var_1, + ) + .with_columns( + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("revenue") + ) + .group_by(["o_orderkey", "o_orderdate", "o_shippriority"]) + .agg([nw.sum("revenue")]) + .select( + [ + nw.col("o_orderkey").alias("l_orderkey"), + "revenue", + "o_orderdate", + "o_shippriority", + ] + ) + .sort(by=["revenue", "o_orderdate"], descending=[True, False]) + .head(10) + ) diff --git a/tpch/queries/q4.py b/tpch/queries/q4.py new file mode 100644 index 000000000..a1b96be15 --- /dev/null +++ b/tpch/queries/q4.py @@ -0,0 +1,26 @@ +from datetime import datetime + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + line_item_ds: FrameT, + orders_ds: FrameT, +) -> FrameT: + var_1 = datetime(1993, 7, 1) + var_2 = datetime(1993, 10, 1) + + return ( + line_item_ds.join(orders_ds, left_on="l_orderkey", right_on="o_orderkey") + .filter( + nw.col("o_orderdate").is_between(var_1, var_2, closed="left"), + nw.col("l_commitdate") < nw.col("l_receiptdate"), + ) + .unique(subset=["o_orderpriority", "l_orderkey"]) + .group_by("o_orderpriority") + .agg(nw.len().alias("order_count")) + .sort(by="o_orderpriority") + 
.with_columns(nw.col("order_count").cast(nw.Int64)) + ) diff --git a/tpch/queries/q5.py b/tpch/queries/q5.py new file mode 100644 index 000000000..2965868c9 --- /dev/null +++ b/tpch/queries/q5.py @@ -0,0 +1,40 @@ +from datetime import datetime + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + region_ds: FrameT, + nation_ds: FrameT, + customer_ds: FrameT, + line_item_ds: FrameT, + orders_ds: FrameT, + supplier_ds: FrameT, +) -> FrameT: + var_1 = "ASIA" + var_2 = datetime(1994, 1, 1) + var_3 = datetime(1995, 1, 1) + + return ( + region_ds.join(nation_ds, left_on="r_regionkey", right_on="n_regionkey") + .join(customer_ds, left_on="n_nationkey", right_on="c_nationkey") + .join(orders_ds, left_on="c_custkey", right_on="o_custkey") + .join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") + .join( + supplier_ds, + left_on=["l_suppkey", "n_nationkey"], + right_on=["s_suppkey", "s_nationkey"], + ) + .filter( + nw.col("r_name") == var_1, + nw.col("o_orderdate").is_between(var_2, var_3, closed="left"), + ) + .with_columns( + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("revenue") + ) + .group_by("n_name") + .agg([nw.sum("revenue")]) + .sort(by="revenue", descending=True) + ) diff --git a/tpch/queries/q6.py b/tpch/queries/q6.py new file mode 100644 index 000000000..67f0ac785 --- /dev/null +++ b/tpch/queries/q6.py @@ -0,0 +1,21 @@ +from datetime import datetime + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query(line_item_ds: FrameT) -> FrameT: + var_1 = datetime(1994, 1, 1) + var_2 = datetime(1995, 1, 1) + var_3 = 24 + + return ( + line_item_ds.filter( + nw.col("l_shipdate").is_between(var_1, var_2, closed="left"), + nw.col("l_discount").is_between(0.05, 0.07), + nw.col("l_quantity") < var_3, + ) + .with_columns((nw.col("l_extendedprice") * nw.col("l_discount")).alias("revenue")) + .select(nw.sum("revenue")) + ) diff --git a/tpch/queries/q7.py b/tpch/queries/q7.py new file mode 100644 index 000000000..ec0946ac3 --- /dev/null +++ b/tpch/queries/q7.py @@ -0,0 +1,51 @@ +from datetime import datetime + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + nation_ds: FrameT, + customer_ds: FrameT, + line_item_ds: FrameT, + orders_ds: FrameT, + supplier_ds: FrameT, +) -> FrameT: + n1 = nation_ds.filter(nw.col("n_name") == "FRANCE") + n2 = nation_ds.filter(nw.col("n_name") == "GERMANY") + + var_1 = datetime(1995, 1, 1) + var_2 = datetime(1996, 12, 31) + + df1 = ( + customer_ds.join(n1, left_on="c_nationkey", right_on="n_nationkey") + .join(orders_ds, left_on="c_custkey", right_on="o_custkey") + .rename({"n_name": "cust_nation"}) + .join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") + .join(supplier_ds, left_on="l_suppkey", right_on="s_suppkey") + .join(n2, left_on="s_nationkey", right_on="n_nationkey") + .rename({"n_name": "supp_nation"}) + ) + + df2 = ( + customer_ds.join(n2, left_on="c_nationkey", right_on="n_nationkey") + .join(orders_ds, left_on="c_custkey", right_on="o_custkey") + .rename({"n_name": "cust_nation"}) + .join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") + .join(supplier_ds, left_on="l_suppkey", right_on="s_suppkey") + .join(n1, left_on="s_nationkey", right_on="n_nationkey") + .rename({"n_name": "supp_nation"}) + ) + + return ( + nw.concat([df1, df2]) + .filter(nw.col("l_shipdate").is_between(var_1, var_2)) + .with_columns( + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("volume") + ) + 
.with_columns(nw.col("l_shipdate").dt.year().alias("l_year")) + .group_by("supp_nation", "cust_nation", "l_year") + .agg(nw.sum("volume").alias("revenue")) + .sort(by=["supp_nation", "cust_nation", "l_year"]) + ) diff --git a/tpch/queries/q8.py b/tpch/queries/q8.py new file mode 100644 index 000000000..ac3fa4baf --- /dev/null +++ b/tpch/queries/q8.py @@ -0,0 +1,52 @@ +from datetime import date + +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + part_ds: FrameT, + supplier_ds: FrameT, + line_item_ds: FrameT, + orders_ds: FrameT, + customer_ds: FrameT, + nation_ds: FrameT, + region_ds: FrameT, +) -> FrameT: + nation = "BRAZIL" + region = "AMERICA" + type = "ECONOMY ANODIZED STEEL" + date1 = date(1995, 1, 1) + date2 = date(1996, 12, 31) + + n1 = nation_ds.select("n_nationkey", "n_regionkey") + n2 = nation_ds.select("n_nationkey", "n_name") + + return ( + part_ds.join(line_item_ds, left_on="p_partkey", right_on="l_partkey") + .join(supplier_ds, left_on="l_suppkey", right_on="s_suppkey") + .join(orders_ds, left_on="l_orderkey", right_on="o_orderkey") + .join(customer_ds, left_on="o_custkey", right_on="c_custkey") + .join(n1, left_on="c_nationkey", right_on="n_nationkey") + .join(region_ds, left_on="n_regionkey", right_on="r_regionkey") + .filter(nw.col("r_name") == region) + .join(n2, left_on="s_nationkey", right_on="n_nationkey") + .filter(nw.col("o_orderdate").is_between(date1, date2)) + .filter(nw.col("p_type") == type) + .select( + nw.col("o_orderdate").dt.year().alias("o_year"), + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("volume"), + nw.col("n_name").alias("nation"), + ) + .with_columns( + nw.when(nw.col("nation") == nation) + .then(nw.col("volume")) + .otherwise(0) + .alias("_tmp") + ) + .group_by("o_year") + .agg(_tmp_sum=nw.sum("_tmp"), volume_sum=nw.sum("volume")) + .select("o_year", mkt_share=nw.col("_tmp_sum") / nw.col("volume_sum")) + .sort("o_year") + ) diff --git a/tpch/queries/q9.py b/tpch/queries/q9.py new file mode 100644 index 000000000..09dff4787 --- /dev/null +++ b/tpch/queries/q9.py @@ -0,0 +1,36 @@ +import narwhals as nw +from narwhals.typing import FrameT + + +@nw.narwhalify +def query( + part_ds: FrameT, + partsupp_ds: FrameT, + nation_ds: FrameT, + lineitem_ds: FrameT, + orders_ds: FrameT, + supplier_ds: FrameT, +) -> FrameT: + return ( + part_ds.join(partsupp_ds, left_on="p_partkey", right_on="ps_partkey") + .join(supplier_ds, left_on="ps_suppkey", right_on="s_suppkey") + .join( + lineitem_ds, + left_on=["p_partkey", "ps_suppkey"], + right_on=["l_partkey", "l_suppkey"], + ) + .join(orders_ds, left_on="l_orderkey", right_on="o_orderkey") + .join(nation_ds, left_on="s_nationkey", right_on="n_nationkey") + .filter(nw.col("p_name").str.contains("green")) + .select( + nw.col("n_name").alias("nation"), + nw.col("o_orderdate").dt.year().alias("o_year"), + ( + nw.col("l_extendedprice") * (1 - nw.col("l_discount")) + - nw.col("ps_supplycost") * nw.col("l_quantity") + ).alias("amount"), + ) + .group_by("nation", "o_year") + .agg(nw.sum("amount").alias("sum_profit")) + .sort(by=["nation", "o_year"], descending=[False, True]) + ) diff --git a/tpch/tests/__init__.py b/tpch/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tpch/tests/test_queries.py b/tpch/tests/test_queries.py new file mode 100644 index 000000000..35909b683 --- /dev/null +++ b/tpch/tests/test_queries.py @@ -0,0 +1,21 @@ +import subprocess +import sys +from pathlib import Path + + +def test_execute_scripts() -> None: + root 
= Path(__file__).resolve().parent.parent + # directory containing all the queries + execute_dir = root / "execute" + + for script_path in execute_dir.glob("q[1-9]*.py"): + print(f"executing query {script_path.stem}") # noqa: T201 + result = subprocess.run( # noqa: S603 + [sys.executable, "-m", f"execute.{script_path.stem}"], + capture_output=True, + text=True, + check=False, + ) + assert ( + result.returncode == 0 + ), f"Script {script_path} failed with error: {result.stderr}" diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index f6e5303c4..ec599def5 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -162,6 +162,7 @@ "value_counts", "zip_with", "item", + "scatter", } ) ): diff --git a/utils/generate_backend_completeness.py b/utils/generate_backend_completeness.py index 537701872..b2cb9df21 100644 --- a/utils/generate_backend_completeness.py +++ b/utils/generate_backend_completeness.py @@ -36,7 +36,7 @@ class Backend(NamedTuple): Backend(name="dask", module="_dask", type_=BackendType.LAZY), ] -EXCLUDE_CLASSES = {"BaseFrame"} +EXCLUDE_CLASSES = {"BaseFrame", "Then", "When"} def get_class_methods(kls: type[Any]) -> list[str]: @@ -113,6 +113,7 @@ def get_backend_completeness_table() -> None: tbl_hide_column_data_types=True, tbl_hide_dataframe_shape=True, set_tbl_rows=results.shape[0], + set_tbl_width_chars=1_000, ): table = str(results)
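Note on usage (not part of the patch itself): every query module added above is wrapped in @nw.narwhalify, so the same function accepts native frames from any backend narwhals supports and returns a frame of the same type. The tpch/execute scripts that the new test harness runs are not included in this diff, so the snippet below is only an illustrative sketch; the data paths, parquet file format, and import layout are assumptions.

# Illustrative sketch only -- the execute scripts are not shown in this diff;
# the paths, parquet format, and "queries" import layout are assumptions.
import pandas as pd
import polars as pl

from queries import q6  # any of the query modules added above works the same way

# Eager pandas backend: @nw.narwhalify translates the native frame in and out.
lineitem_pd = pd.read_parquet("data/lineitem.parquet")
print(q6.query(lineitem_pd))

# Lazy polars backend: the same query runs unchanged and returns a LazyFrame.
lineitem_pl = pl.scan_parquet("data/lineitem.parquet")
print(q6.query(lineitem_pl).collect())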