diff --git a/.github/workflows/check_docs_build.yml b/.github/workflows/check_docs_build.yml index 3b418baff..c59e67f46 100644 --- a/.github/workflows/check_docs_build.yml +++ b/.github/workflows/check_docs_build.yml @@ -18,8 +18,12 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv (Unix) - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: "**requirements*.txt" - name: install-reqs run: uv pip install --upgrade tox virtualenv setuptools pip -r requirements-dev.txt --system - name: install-docs-reqs diff --git a/.github/workflows/check_tpch_queries.yml b/.github/workflows/check_tpch_queries.yml index 46dd5df20..509ebb95b 100644 --- a/.github/workflows/check_tpch_queries.yml +++ b/.github/workflows/check_tpch_queries.yml @@ -18,8 +18,12 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: "**requirements*.txt" - name: install-reqs run: uv pip install --upgrade -r requirements-dev.txt --system - name: local-install @@ -27,4 +31,4 @@ jobs: - name: generate-data run: cd tpch && python generate_data.py - name: tpch-tests - run: cd tpch && pytest tests \ No newline at end of file + run: cd tpch && pytest tests diff --git a/.github/workflows/downstream_tests.yml b/.github/workflows/downstream_tests.yml index c733e348d..26e9edf7a 100644 --- a/.github/workflows/downstream_tests.yml +++ b/.github/workflows/downstream_tests.yml @@ -18,8 +18,12 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv (Unix) - run: curl -LsSf 
https://astral.sh/uv/install.sh | sh + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: "**requirements*.txt" - name: clone-altair run: | git clone https://github.com/vega/altair.git --depth=1 @@ -58,8 +62,12 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv (Unix) - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: "**requirements*.txt" - name: clone-scikit-lego run: git clone https://github.com/koaning/scikit-lego.git --depth 1 - name: install-basics diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml index 7e1a5586e..cec9d32d7 100644 --- a/.github/workflows/extremes.yml +++ b/.github/workflows/extremes.yml @@ -18,8 +18,12 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv (Unix) - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: "**requirements*.txt" - name: install-minimum-versions run: uv pip install tox virtualenv setuptools pandas==0.25.3 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system - name: install-reqs @@ -41,8 +45,12 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv (Unix) - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: "**requirements*.txt" - name: install-minimum-versions run: uv pip install tox virtualenv setuptools pandas==1.1.5 
polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system - name: install-reqs @@ -66,8 +74,12 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv (Unix) - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: "**requirements*.txt" - name: install-minimum-versions run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==14.0.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.7 tzdata --system - name: install-reqs @@ -79,22 +91,34 @@ jobs: - name: Run doctests run: pytest narwhals --doctest-modules - pandas-nightly-and-dask: + nightlies: strategy: matrix: - python-version: ["3.12"] + python-version: ["3.11"] os: [ubuntu-latest] - + if: github.event.pull_request.head.repo.full_name == github.repository runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv (Unix) - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: "**requirements*.txt" + - name: install-kaggle + run: uv pip install kaggle --system + - name: Download Kaggle notebook artifact + env: + KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }} + KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }} + run: | + kaggle kernels output "marcogorelli/variable-brink-glacier" - name: install-polars - run: uv pip install polars --system + run: python -m pip install *.whl - name: install-reqs run: uv pip install --upgrade tox virtualenv setuptools pip -r requirements-dev.txt --system - name: uninstall pyarrow @@ -119,43 +143,4 @@ jobs: - name: Run doctests run: pytest narwhals 
--doctest-modules - # polars-nightly: - # if: github.ref == 'refs/heads/main' - # strategy: - # matrix: - # python-version: ["3.12"] - # os: [ubuntu-latest] - # runs-on: ${{ matrix.os }} - # steps: - # - uses: actions/checkout@v4 - # - uses: actions/setup-python@v5 - # with: - # python-version: ${{ matrix.python-version }} - # - name: Cache multiple paths - # uses: actions/cache@v4 - # with: - # path: | - # ~/.cache/pip - # $RUNNER_TOOL_CACHE/Python/* - # ~\AppData\Local\pip\Cache - # key: ${{ runner.os }}-build-${{ matrix.python-version }} - # - name: install-kaggle - # run: python -m pip install kaggle - # - name: Download Kaggle notebook artifact - # env: - # KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }} - # KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }} - # run: kaggle kernels output marcogorelli/polars-nightly - # - name: install-reqs - # run: python -m pip install --upgrade tox virtualenv setuptools pip -r requirements-dev.txt - # - name: uninstall polars - # run: python -m pip uninstall polars -y - # - name: install-modin-pandas - # run: pip install modin[dask] pandas - # - name: install-polars-nightly - # run: python -m pip install *.whl - # - name: Run pytest - # run: pytest tests --cov=narwhals --cov=tests --cov-fail-under=50 - # - name: Run doctests - # run: pytest narwhals --doctest-modules diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 265442e9f..39b5c91b3 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -18,12 +18,12 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv (Unix) - if: runner.os != 'Windows' - run: curl -LsSf https://astral.sh/uv/install.sh | sh - - name: Install uv (Windows) - if: runner.os == 'Windows' - run: powershell -c "irm https://astral.sh/uv/install.ps1 | iex" + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: 
"**requirements*.txt" - name: install-reqs run: uv pip install --upgrade tox virtualenv setuptools -r requirements-dev.txt ibis-framework[duckdb] --system - name: show-deps @@ -45,8 +45,12 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv (Windows) - run: powershell -c "irm https://astral.sh/uv/install.ps1 | iex" + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: "**requirements*.txt" - name: install-reqs run: uv pip install --upgrade tox virtualenv setuptools -r requirements-dev.txt --system - name: install-modin @@ -70,8 +74,12 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv (Unix) - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: "**requirements*.txt" - name: install-reqs run: uv pip install --upgrade tox virtualenv setuptools -r requirements-dev.txt --system - name: install-modin diff --git a/.github/workflows/random_ci_pytest.yml b/.github/workflows/random_ci_pytest.yml index e25bdcb68..b029aba3c 100644 --- a/.github/workflows/random_ci_pytest.yml +++ b/.github/workflows/random_ci_pytest.yml @@ -16,8 +16,12 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install uv (Unix) - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: "true" + cache-suffix: ${{ matrix.python-version }} + cache-dependency-glob: "**requirements*.txt" - name: install package run: uv pip install -e . 
--system - name: generate-random-versions diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f3a68e7a0..04e41ea30 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.6.3' + rev: 'v0.6.4' hooks: # Run the formatter. - id: ruff-format diff --git a/README.md b/README.md index 74630fd03..29623920a 100644 --- a/README.md +++ b/README.md @@ -95,9 +95,6 @@ See the [tutorial](https://narwhals-dev.github.io/narwhals/basics/dataframe/) fo If you said yes to both, we'd love to hear from you! -**Note**: You might suspect that this is a secret ploy to infiltrate the Polars API everywhere. -Indeed, you may suspect that. - ## Sponsors and institutional partners Narwhals is 100% independent, community-driven, and community-owned. diff --git a/docs/api-reference/dataframe.md b/docs/api-reference/dataframe.md index f78b4e3da..fe0c4025d 100644 --- a/docs/api-reference/dataframe.md +++ b/docs/api-reference/dataframe.md @@ -29,6 +29,7 @@ - rename - row - rows + - sample - schema - select - shape diff --git a/docs/basics/complete_example.md b/docs/basics/complete_example.md index 1e2cfe30d..d3b4ecfe4 100644 --- a/docs/basics/complete_example.md +++ b/docs/basics/complete_example.md @@ -10,10 +10,6 @@ We'll need to write two methods: - `transform`: scale a given dataset with the mean and standard deviations calculated during `fit`. -The `fit` method is a bit complicated, so let's start with `transform`. -Suppose we've already calculated the mean and standard deviation of each column, and have -stored them in attributes `self.means` and `self.std_devs`. 
- ## Fit method Unlike the `transform` method, which we'll write below, `fit` cannot stay lazy, diff --git a/docs/installation.md b/docs/installation.md index 5a49dba8f..617606817 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -11,6 +11,6 @@ Then, if you start the Python REPL and see the following: ```python >>> import narwhals >>> narwhals.__version__ -'1.7.0' +'1.8.1' ``` then installation worked correctly! diff --git a/narwhals/__init__.py b/narwhals/__init__.py index d76ad2262..3c8b7b776 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -53,7 +53,7 @@ from narwhals.utils import maybe_get_index from narwhals.utils import maybe_set_index -__version__ = "1.7.0" +__version__ = "1.8.1" __all__ = [ "dependencies", diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index f409ef735..428e83e3b 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -9,6 +9,7 @@ from typing import overload from narwhals._arrow.utils import broadcast_series +from narwhals._arrow.utils import convert_slice_to_nparray from narwhals._arrow.utils import translate_dtype from narwhals._arrow.utils import validate_dataframe_comparand from narwhals._expression_parsing import evaluate_into_exprs @@ -126,7 +127,8 @@ def __getitem__( | slice | Sequence[int] | Sequence[str] - | tuple[Sequence[int], str | int], + | tuple[Sequence[int], str | int] + | tuple[slice, str | int], ) -> ArrowSeries | ArrowDataFrame: if isinstance(item, str): from narwhals._arrow.series import ArrowSeries @@ -144,7 +146,10 @@ def __getitem__( if item[0] == slice(None): selected_rows = self._native_frame else: - selected_rows = self._native_frame.take(item[0]) + range_ = convert_slice_to_nparray( + num_rows=len(self._native_frame), rows_slice=item[0] + ) + selected_rows = self._native_frame.take(range_) return self._from_native_frame(selected_rows.select(item[1])) @@ -174,13 +179,24 @@ def __getitem__( ) msg = f"Expected slice of integers or strings, 
got: {type(item[1])}" # pragma: no cover raise TypeError(msg) # pragma: no cover - from narwhals._arrow.series import ArrowSeries # PyArrow columns are always strings col_name = item[1] if isinstance(item[1], str) else self.columns[item[1]] + if isinstance(item[0], str): # pragma: no cover + msg = "Can not slice with tuple with the first element as a str" + raise TypeError(msg) + if (isinstance(item[0], slice)) and (item[0] == slice(None)): + return ArrowSeries( + self._native_frame[col_name], + name=col_name, + backend_version=self._backend_version, + ) + range_ = convert_slice_to_nparray( + num_rows=len(self._native_frame), rows_slice=item[0] + ) return ArrowSeries( - self._native_frame[col_name].take(item[0]), + self._native_frame[col_name].take(range_), name=col_name, backend_version=self._backend_version, ) @@ -572,3 +588,25 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: def to_arrow(self: Self) -> Any: return self._native_frame + + def sample( + self: Self, + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, + ) -> Self: + import numpy as np # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import() + + frame = self._native_frame + num_rows = len(self) + if n is None and fraction is not None: + n = int(num_rows * fraction) + + rng = np.random.default_rng(seed=seed) + idx = np.arange(0, num_rows) + mask = rng.choice(idx, size=n, replace=with_replacement) + + return self._from_native_frame(pc.take(frame, mask)) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 24e4fe5c5..b324be2b1 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -249,12 +249,18 @@ def arg_true(self) -> Self: def sample( self: Self, n: int | None = None, - fraction: float | None = None, *, + fraction: float | None = None, with_replacement: bool = False, + seed: int | None = None, ) -> Self: return reuse_series_implementation( - self, "sample", n=n, 
fraction=fraction, with_replacement=with_replacement + self, + "sample", + n=n, + fraction=fraction, + with_replacement=with_replacement, + seed=seed, ) def fill_null(self: Self, value: Any) -> Self: diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 73390fdd3..4f53a3f00 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -534,9 +534,10 @@ def zip_with(self: Self, mask: Self, other: Self) -> Self: def sample( self: Self, n: int | None = None, - fraction: float | None = None, *, + fraction: float | None = None, with_replacement: bool = False, + seed: int | None = None, ) -> Self: import numpy as np # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import() @@ -547,8 +548,10 @@ def sample( if n is None and fraction is not None: n = int(num_rows * fraction) + rng = np.random.default_rng(seed=seed) idx = np.arange(0, num_rows) - mask = np.random.choice(idx, size=n, replace=with_replacement) + mask = rng.choice(idx, size=n, replace=with_replacement) + return self._from_native_series(pc.take(ser, mask)) def fill_null(self: Self, value: Any) -> Self: @@ -928,6 +931,7 @@ def get_categories(self) -> ArrowSeries: ca = self._arrow_series._native_series # TODO(Unassigned): this looks potentially expensive - is there no better way? 
+ # https://github.com/narwhals-dev/narwhals/issues/464 out = pa.chunked_array( [pa.concat_arrays([x.dictionary for x in ca.chunks]).unique()] ) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index b8294839c..a2a45586b 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Sequence from narwhals import dtypes from narwhals.utils import isinstance_or_issubclass @@ -11,7 +12,7 @@ def translate_dtype(dtype: Any) -> dtypes.DType: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import if pa.types.is_int64(dtype): return dtypes.Int64() @@ -55,7 +56,7 @@ def translate_dtype(dtype: Any) -> dtypes.DType: def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import from narwhals import dtypes @@ -84,8 +85,6 @@ def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any: if isinstance_or_issubclass(dtype, dtypes.Boolean): return pa.bool_() if isinstance_or_issubclass(dtype, dtypes.Categorical): - # TODO(Unassigned): what should the key be? let's keep it consistent - # with Polars for now return pa.dictionary(pa.uint32(), pa.string()) if isinstance_or_issubclass(dtype, dtypes.Datetime): # Use Polars' default @@ -142,7 +141,7 @@ def validate_dataframe_comparand( return NotImplemented if isinstance(other, ArrowSeries): if len(other) == 1: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import value = other.item() if backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover @@ -159,7 +158,7 @@ def horizontal_concat(dfs: list[Any]) -> Any: Should be in namespace. 
""" - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import if not dfs: msg = "No dataframes to concatenate" # pragma: no cover @@ -192,7 +191,7 @@ def vertical_concat(dfs: list[Any]) -> Any: msg = "unable to vstack, column names don't match" raise TypeError(msg) - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import return pa.concat_tables(dfs).combine_chunks() @@ -200,8 +199,8 @@ def vertical_concat(dfs: list[Any]) -> Any: def floordiv_compat(left: Any, right: Any) -> Any: # The following lines are adapted from pandas' pyarrow implementation. # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154 - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import if isinstance(left, (int, float)): left = pa.scalar(left) @@ -239,8 +238,8 @@ def floordiv_compat(left: Any, right: Any) -> Any: def cast_for_truediv(arrow_array: Any, pa_object: Any) -> tuple[Any, Any]: # Lifted from: # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122 - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import # Ensure int / int -> float mirroring Python/Numpy behavior # as pc.divide_checked(int, int) -> int @@ -262,7 +261,7 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]: if fast_path: return [s._native_series for s in series] - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa # ignore-banned-import reshaped = [] for s, length in zip(series, lengths): @@ -276,3 +275,14 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]: 
reshaped.append(s_native) return reshaped + + +def convert_slice_to_nparray( + num_rows: int, rows_slice: slice | int | Sequence[int] +) -> Any: + import numpy as np # ignore-banned-import + + if isinstance(rows_slice, slice): + return np.arange(num_rows)[rows_slice] + else: + return rows_slice diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 180a897bd..ac10ac2b8 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -213,6 +213,10 @@ def join( right_on: str | list[str] | None, suffix: str, ) -> Self: + if isinstance(left_on, str): + left_on = [left_on] + if isinstance(right_on, str): + right_on = [right_on] if how == "cross": key_token = generate_unique_token( n_bytes=8, columns=[*self.columns, *other.columns] diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index f08af590c..bb5f8ddcb 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -103,13 +103,11 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: if root_names is not None and isinstance(arg, self.__class__): if arg._root_names is not None: root_names.extend(arg._root_names) - else: # pragma: no cover - # TODO(unassigned): increase coverage + else: root_names = None output_names = None break - elif root_names is None: # pragma: no cover - # TODO(unassigned): increase coverage + elif root_names is None: output_names = None break @@ -431,6 +429,11 @@ def round(self, decimals: int) -> Self: returns_scalar=False, ) + def unique(self) -> NoReturn: + # We can't (yet?) allow methods which modify the index + msg = "`Expr.unique` is not supported for the Dask backend. Please use `LazyFrame.unique` instead." + raise NotImplementedError(msg) + def drop_nulls(self) -> NoReturn: # We can't (yet?) allow methods which modify the index msg = "`Expr.drop_nulls` is not supported for the Dask backend. Please use `LazyFrame.drop_nulls` instead." 
diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 71a659998..620670696 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -693,3 +693,17 @@ def to_arrow(self: Self) -> Any: import pyarrow as pa # ignore-banned-import() return pa.Table.from_pandas(self._native_frame) + + def sample( + self: Self, + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, + ) -> Self: + return self._from_native_frame( + self._native_frame.sample( + n=n, frac=fraction, replace=with_replacement, random_state=seed + ) + ) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 74a2ee31d..409d1ab09 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -254,14 +254,20 @@ def shift(self, n: int) -> Self: return reuse_series_implementation(self, "shift", n=n) def sample( - self, + self: Self, n: int | None = None, - fraction: float | None = None, *, + fraction: float | None = None, with_replacement: bool = False, + seed: int | None = None, ) -> Self: return reuse_series_implementation( - self, "sample", n=n, fraction=fraction, with_replacement=with_replacement + self, + "sample", + n=n, + fraction=fraction, + with_replacement=with_replacement, + seed=seed, ) def alias(self, name: str) -> Self: diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 8288be263..14d7a128e 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -436,15 +436,16 @@ def n_unique(self) -> int: return ser.nunique(dropna=False) # type: ignore[no-any-return] def sample( - self, + self: Self, n: int | None = None, - fraction: float | None = None, *, + fraction: float | None = None, with_replacement: bool = False, - ) -> PandasLikeSeries: + seed: int | None = None, + ) -> Self: ser = self._native_series return self._from_native_series( - ser.sample(n=n, 
frac=fraction, replace=with_replacement) + ser.sample(n=n, frac=fraction, replace=with_replacement, random_state=seed) ) def abs(self) -> PandasLikeSeries: diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 1b91f0910..b51d53baa 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -602,11 +602,15 @@ def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ... @overload def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap] @overload + def __getitem__(self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap] + @overload def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ... @overload def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ... @overload def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap] + @overload + def __getitem__(self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap] @overload def __getitem__(self, item: Sequence[int]) -> Self: ... @@ -627,6 +631,7 @@ def __getitem__( | Sequence[int] | Sequence[str] | tuple[Sequence[int], str | int] + | tuple[slice, str | int] | tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice], ) -> Series | Self: """ @@ -801,6 +806,9 @@ def row(self, index: int) -> tuple[Any, ...]: Arguments: index: Row number. + Notes: + cuDF doesn't support this method. + Examples: >>> import narwhals as nw >>> import pandas as pd @@ -1137,6 +1145,9 @@ def iter_rows( internally while iterating over the data. See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.iter_rows.html + Notes: + cuDF doesn't support this method. 
+ Examples: >>> import pandas as pd >>> import polars as pl @@ -2453,6 +2464,66 @@ def to_arrow(self: Self) -> pa.Table: """ return self._compliant_frame.to_arrow() + def sample( + self: Self, + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, + ) -> Self: + r""" + Sample from this DataFrame. + + Arguments: + n: Number of items to return. Cannot be used with fraction. + fraction: Fraction of items to return. Cannot be used with n. + with_replacement: Allow values to be sampled more than once. + seed: Seed for the random number generator. If set to None (default), a random + seed is generated for each sample operation. + + Notes: + The results may not be consistent across libraries. + + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> data = {"a": [1, 2, 3, 4], "b": ["x", "y", "x", "y"]} + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(df): + ... return df.sample(n=2, seed=123) + + We can then pass either pandas or Polars to `func`: + >>> func(df_pd) + a b + 3 4 y + 0 1 x + >>> func(df_pl) + shape: (2, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ str │ + ╞═════╪═════╡ + │ 2 ┆ y │ + │ 3 ┆ x │ + └─────┴─────┘ + + As you can see, by using the same seed, the result will be consistent within + the same backend, but not necessarily across different backends. + """ + return self._from_compliant_dataframe( + self._compliant_frame.sample( + n=n, fraction=fraction, with_replacement=with_replacement, seed=seed + ) + ) + class LazyFrame(BaseFrame[FrameT]): """ @@ -3851,34 +3922,35 @@ def clone(self) -> Self: r""" Create a copy of this DataFrame. 
- >>> import narwhals as nw - >>> import pandas as pd - >>> import polars as pl - >>> data = {"a": [1, 2], "b": [3, 4]} - >>> df_pd = pd.DataFrame(data) - >>> df_pl = pl.LazyFrame(data) - - Let's define a dataframe-agnostic function in which we copy the DataFrame: - - >>> @nw.narwhalify - ... def func(df): - ... return df.clone() - - >>> func(df_pd) - a b - 0 1 3 - 1 2 4 - - >>> func(df_pl).collect() - shape: (2, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 1 ┆ 3 │ - │ 2 ┆ 4 │ - └─────┴─────┘ + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> data = {"a": [1, 2], "b": [3, 4]} + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.LazyFrame(data) + + Let's define a dataframe-agnostic function in which we copy the DataFrame: + + >>> @nw.narwhalify + ... def func(df): + ... return df.clone() + + >>> func(df_pd) + a b + 0 1 3 + 1 2 4 + + >>> func(df_pl).collect() + shape: (2, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 3 │ + │ 2 ┆ 4 │ + └─────┴─────┘ """ return super().clone() diff --git a/narwhals/expr.py b/narwhals/expr.py index a8407915a..13c50dec7 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1228,21 +1228,22 @@ def drop_nulls(self) -> Self: return self.__class__(lambda plx: self._call(plx).drop_nulls()) def sample( - self, + self: Self, n: int | None = None, - fraction: float | None = None, *, + fraction: float | None = None, with_replacement: bool = False, + seed: int | None = None, ) -> Self: """ Sample randomly from this expression. Arguments: n: Number of items to return. Cannot be used with fraction. - fraction: Fraction of items to return. Cannot be used with n. - with_replacement: Allow values to be sampled more than once. + seed: Seed for the random number generator. If set to None (default), a random + seed is generated for each sample operation. 
Examples: >>> import narwhals as nw @@ -1279,7 +1280,7 @@ def sample( """ return self.__class__( lambda plx: self._call(plx).sample( - n, fraction=fraction, with_replacement=with_replacement + n, fraction=fraction, with_replacement=with_replacement, seed=seed ) ) @@ -1932,25 +1933,23 @@ def mode(self: Self) -> Self: >>> @nw.narwhalify ... def func(df): - ... return df.select(nw.col("a", "b").mode()).sort("a", "b") + ... return df.select(nw.col("a").mode()).sort("a") We can then pass either pandas or Polars to `func`: >>> func(df_pd) - a b - 0 1 1 - 1 1 2 + a + 0 1 >>> func(df_pl) - shape: (2, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 1 ┆ 1 │ - │ 1 ┆ 2 │ - └─────┴─────┘ + shape: (1, 1) + ┌─────┐ + │ a │ + │ --- │ + │ i64 │ + ╞═════╡ + │ 1 │ + └─────┘ """ return self.__class__(lambda plx: self._call(plx).mode()) diff --git a/narwhals/series.py b/narwhals/series.py index 9fcb07a23..9d21058f5 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -1076,21 +1076,22 @@ def shift(self, n: int) -> Self: return self._from_compliant_series(self._compliant_series.shift(n)) def sample( - self, + self: Self, n: int | None = None, - fraction: float | None = None, *, + fraction: float | None = None, with_replacement: bool = False, + seed: int | None = None, ) -> Self: """ Sample randomly from this Series. Arguments: n: Number of items to return. Cannot be used with fraction. - fraction: Fraction of items to return. Cannot be used with n. - with_replacement: Allow values to be sampled more than once. + seed: Seed for the random number generator. If set to None (default), a random + seed is generated for each sample operation. 
Notes: The `sample` method returns a Series with a specified number of @@ -1131,7 +1132,7 @@ def sample( """ return self._from_compliant_series( self._compliant_series.sample( - n=n, fraction=fraction, with_replacement=with_replacement + n=n, fraction=fraction, with_replacement=with_replacement, seed=seed ) ) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index dde7ca5fd..4758e938f 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -80,21 +80,25 @@ def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ... def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ... @overload def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ... - @overload def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap] @overload + def __getitem__(self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap] + @overload def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ... @overload def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ... - @overload def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap] + @overload + def __getitem__(self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap] @overload def __getitem__(self, item: Sequence[int]) -> Self: ... + @overload def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap] + @overload def __getitem__(self, item: Sequence[str]) -> Self: ... 
diff --git a/narwhals/translate.py b/narwhals/translate.py index 69a99ea2b..7b7d09de5 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -391,9 +391,6 @@ def from_native( # noqa: PLR0915 level="full", ) - # TODO(marco): write all of these in terms of `is_` rather - # than `get_` + walrus - # Polars elif is_polars_dataframe(native_object): if series_only: diff --git a/pyproject.toml b/pyproject.toml index 5ec7fef5f..efc02f9bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "narwhals" -version = "1.7.0" +version = "1.8.1" authors = [ { name="Marco Gorelli", email="33491632+MarcoGorelli@users.noreply.github.com" }, ] @@ -114,6 +114,7 @@ filterwarnings = [ 'ignore:.*You are using pyarrow version', 'ignore:.*but when imported by', 'ignore:Distributing .*This may take some time', + 'ignore:.*The default coalesce behavior' ] xfail_strict = true markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] diff --git a/requirements-dev.txt b/requirements-dev.txt index 23ff1757e..44d57530d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,4 +10,5 @@ pytest-cov pytest-env hypothesis scikit-learn +typing_extensions dask[dataframe]; python_version >= '3.9' diff --git a/tests/conftest.py b/tests/conftest.py index 011b83265..85c296daf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ from narwhals.typing import IntoDataFrame from narwhals.typing import IntoFrame from narwhals.utils import parse_version +from tests.utils import Constructor with contextlib.suppress(ImportError): import modin.pandas # noqa: F401 @@ -107,10 +108,10 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame: @pytest.fixture(params=eager_constructors) -def constructor_eager(request: Any) -> Callable[[Any], IntoDataFrame]: +def constructor_eager(request: pytest.FixtureRequest) -> Callable[[Any], IntoDataFrame]: return request.param # type: ignore[no-any-return] 
@pytest.fixture(params=[*eager_constructors, *lazy_constructors]) -def constructor(request: Any) -> Callable[[Any], Any]: +def constructor(request: pytest.FixtureRequest) -> Constructor: return request.param # type: ignore[no-any-return] diff --git a/tests/expr_and_series/abs_test.py b/tests/expr_and_series/abs_test.py index e684528b8..286bcca19 100644 --- a/tests/expr_and_series/abs_test.py +++ b/tests/expr_and_series/abs_test.py @@ -1,10 +1,11 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_abs(constructor: Any) -> None: +def test_abs(constructor: Constructor) -> None: df = nw.from_native(constructor({"a": [1, 2, 3, -4, 5]})) result = df.select(b=nw.col("a").abs()) expected = {"b": [1, 2, 3, 4, 5]} diff --git a/tests/expr_and_series/all_horizontal_test.py b/tests/expr_and_series/all_horizontal_test.py index 256c45deb..bc9f80358 100644 --- a/tests/expr_and_series/all_horizontal_test.py +++ b/tests/expr_and_series/all_horizontal_test.py @@ -3,12 +3,13 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts @pytest.mark.parametrize("expr1", ["a", nw.col("a")]) @pytest.mark.parametrize("expr2", ["b", nw.col("b")]) -def test_allh(constructor: Any, expr1: Any, expr2: Any) -> None: +def test_allh(constructor: Constructor, expr1: Any, expr2: Any) -> None: data = { "a": [False, False, True], "b": [False, True, True], diff --git a/tests/expr_and_series/any_all_test.py b/tests/expr_and_series/any_all_test.py index 09cc8c9e3..834a91202 100644 --- a/tests/expr_and_series/any_all_test.py +++ b/tests/expr_and_series/any_all_test.py @@ -1,10 +1,11 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_any_all(constructor: Any) -> None: +def test_any_all(constructor: Constructor) -> None: df = nw.from_native( constructor( { diff 
--git a/tests/expr_and_series/any_horizontal_test.py b/tests/expr_and_series/any_horizontal_test.py index 1f19aa304..1b6dfd48d 100644 --- a/tests/expr_and_series/any_horizontal_test.py +++ b/tests/expr_and_series/any_horizontal_test.py @@ -3,12 +3,13 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts @pytest.mark.parametrize("expr1", ["a", nw.col("a")]) @pytest.mark.parametrize("expr2", ["b", nw.col("b")]) -def test_anyh(constructor: Any, expr1: Any, expr2: Any) -> None: +def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None: data = { "a": [False, False, True], "b": [False, True, True], diff --git a/tests/expr_and_series/arg_true_test.py b/tests/expr_and_series/arg_true_test.py index eaa3d1ba6..7e1262aa8 100644 --- a/tests/expr_and_series/arg_true_test.py +++ b/tests/expr_and_series/arg_true_test.py @@ -3,10 +3,11 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_arg_true(constructor: Any, request: Any) -> None: +def test_arg_true(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, None, None, 3]})) diff --git a/tests/expr_and_series/arithmetic_test.py b/tests/expr_and_series/arithmetic_test.py index 7ff945c80..e431aebbe 100644 --- a/tests/expr_and_series/arithmetic_test.py +++ b/tests/expr_and_series/arithmetic_test.py @@ -12,6 +12,7 @@ import narwhals.stable.v1 as nw from narwhals.utils import parse_version +from tests.utils import Constructor from tests.utils import compare_dicts @@ -32,8 +33,8 @@ def test_arithmetic_expr( attr: str, rhs: Any, expected: list[Any], - constructor: Any, - request: Any, + constructor: Constructor, + request: pytest.FixtureRequest, ) -> None: if attr == "__mod__" and any( x in str(constructor) for x in ["pandas_pyarrow", "modin"] @@ 
-62,8 +63,8 @@ def test_right_arithmetic_expr( attr: str, rhs: Any, expected: list[Any], - constructor: Any, - request: Any, + constructor: Constructor, + request: pytest.FixtureRequest, ) -> None: if attr == "__rmod__" and any( x in str(constructor) for x in ["pandas_pyarrow", "modin"] @@ -94,7 +95,7 @@ def test_arithmetic_series( rhs: Any, expected: list[Any], constructor_eager: Any, - request: Any, + request: pytest.FixtureRequest, ) -> None: if attr == "__mod__" and any( x in str(constructor_eager) for x in ["pandas_pyarrow", "modin"] @@ -124,7 +125,7 @@ def test_right_arithmetic_series( rhs: Any, expected: list[Any], constructor_eager: Any, - request: Any, + request: pytest.FixtureRequest, ) -> None: if attr == "__rmod__" and any( x in str(constructor_eager) for x in ["pandas_pyarrow", "modin"] @@ -137,7 +138,9 @@ def test_right_arithmetic_series( compare_dicts(result, {"a": expected}) -def test_truediv_same_dims(constructor_eager: Any, request: Any) -> None: +def test_truediv_same_dims( + constructor_eager: Any, request: pytest.FixtureRequest +) -> None: if "polars" in str(constructor_eager): # https://github.com/pola-rs/polars/issues/17760 request.applymarker(pytest.mark.xfail) diff --git a/tests/expr_and_series/binary_test.py b/tests/expr_and_series/binary_test.py index 2d55af228..1ce76d9d2 100644 --- a/tests/expr_and_series/binary_test.py +++ b/tests/expr_and_series/binary_test.py @@ -1,10 +1,9 @@ -from typing import Any - import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_expr_binary(constructor: Any) -> None: +def test_expr_binary(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_raw = constructor(data) result = nw.from_native(df_raw).with_columns( diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index 0b496d7ae..00f242148 100644 --- a/tests/expr_and_series/cast_test.py +++ 
b/tests/expr_and_series/cast_test.py @@ -1,11 +1,10 @@ -from typing import Any - import pandas as pd import pyarrow as pa import pytest import narwhals.stable.v1 as nw from narwhals.utils import parse_version +from tests.utils import Constructor data = { "a": [1], @@ -46,7 +45,7 @@ @pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning") -def test_cast(constructor: Any, request: Any) -> None: +def test_cast(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "pyarrow_table_constructor" in str(constructor) and parse_version( pa.__version__ ) <= (15,): # pragma: no cover @@ -96,7 +95,7 @@ def test_cast(constructor: Any, request: Any) -> None: assert dict(result.collect_schema()) == expected -def test_cast_series(constructor: Any, request: Any) -> None: +def test_cast_series(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "pyarrow_table_constructor" in str(constructor) and parse_version( pa.__version__ ) <= (15,): # pragma: no cover @@ -162,7 +161,9 @@ def test_cast_string() -> None: assert str(result.dtype) in ("string", "object", "dtype('O')") -def test_cast_raises_for_unknown_dtype(constructor: Any, request: Any) -> None: +def test_cast_raises_for_unknown_dtype( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: if "pyarrow_table_constructor" in str(constructor) and parse_version( pa.__version__ ) <= (15,): # pragma: no cover diff --git a/tests/expr_and_series/cat/get_categories_test.py b/tests/expr_and_series/cat/get_categories_test.py index 6432826c2..122f3c83e 100644 --- a/tests/expr_and_series/cat/get_categories_test.py +++ b/tests/expr_and_series/cat/get_categories_test.py @@ -12,7 +12,7 @@ data = {"a": ["one", "two", "two"]} -def test_get_categories(request: Any, constructor_eager: Any) -> None: +def test_get_categories(request: pytest.FixtureRequest, constructor_eager: Any) -> None: if "pyarrow_table" in str(constructor_eager) and parse_version( pa.__version__ ) < 
parse_version("15.0.0"): diff --git a/tests/expr_and_series/clip_test.py b/tests/expr_and_series/clip_test.py index 909b153b7..d3f90633c 100644 --- a/tests/expr_and_series/clip_test.py +++ b/tests/expr_and_series/clip_test.py @@ -1,10 +1,11 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_clip(constructor: Any) -> None: +def test_clip(constructor: Constructor) -> None: df = nw.from_native(constructor({"a": [1, 2, 3, -4, 5]})) result = df.select( lower_only=nw.col("a").clip(lower_bound=3), diff --git a/tests/expr_and_series/count_test.py b/tests/expr_and_series/count_test.py index 208df3bc1..580bd202b 100644 --- a/tests/expr_and_series/count_test.py +++ b/tests/expr_and_series/count_test.py @@ -1,10 +1,11 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_count(constructor: Any) -> None: +def test_count(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, None, 6], "z": [7.0, None, None]} df = nw.from_native(constructor(data)) result = df.select(nw.col("a", "b", "z").count()) diff --git a/tests/expr_and_series/cum_sum_test.py b/tests/expr_and_series/cum_sum_test.py index e169b28f9..94897a850 100644 --- a/tests/expr_and_series/cum_sum_test.py +++ b/tests/expr_and_series/cum_sum_test.py @@ -1,6 +1,7 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -10,7 +11,7 @@ } -def test_cum_sum_simple(constructor: Any) -> None: +def test_cum_sum_simple(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a", "b", "c").cum_sum()) expected = { diff --git a/tests/expr_and_series/diff_test.py b/tests/expr_and_series/diff_test.py index f38b96e00..33445f763 100644 --- a/tests/expr_and_series/diff_test.py +++ b/tests/expr_and_series/diff_test.py @@ 
-5,6 +5,7 @@ import narwhals.stable.v1 as nw from narwhals.utils import parse_version +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -14,7 +15,7 @@ } -def test_diff(constructor: Any, request: Any) -> None: +def test_diff(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "pyarrow_table_constructor" in str(constructor) and parse_version( pa.__version__ ) < (13,): @@ -31,7 +32,7 @@ def test_diff(constructor: Any, request: Any) -> None: compare_dicts(result, expected) -def test_diff_series(constructor_eager: Any, request: Any) -> None: +def test_diff_series(constructor_eager: Any, request: pytest.FixtureRequest) -> None: if "pyarrow_table_constructor" in str(constructor_eager) and parse_version( pa.__version__ ) < (13,): diff --git a/tests/expr_and_series/double_selected_test.py b/tests/expr_and_series/double_selected_test.py index 7b8fd6703..88826fb40 100644 --- a/tests/expr_and_series/double_selected_test.py +++ b/tests/expr_and_series/double_selected_test.py @@ -1,10 +1,9 @@ -from typing import Any - import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_double_selected(constructor: Any) -> None: +def test_double_selected(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) diff --git a/tests/expr_and_series/double_test.py b/tests/expr_and_series/double_test.py index 3a6b622b8..8f19e0202 100644 --- a/tests/expr_and_series/double_test.py +++ b/tests/expr_and_series/double_test.py @@ -1,10 +1,9 @@ -from typing import Any - import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_double(constructor: Any) -> None: +def test_double(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.with_columns(nw.all() * 2) @@ -12,7 +11,7 
@@ def test_double(constructor: Any) -> None: compare_dicts(result, expected) -def test_double_alias(constructor: Any) -> None: +def test_double_alias(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.with_columns(nw.col("a").alias("o"), nw.all() * 2) diff --git a/tests/expr_and_series/drop_nulls_test.py b/tests/expr_and_series/drop_nulls_test.py index f4c8e2d7a..bc06eec3a 100644 --- a/tests/expr_and_series/drop_nulls_test.py +++ b/tests/expr_and_series/drop_nulls_test.py @@ -5,10 +5,11 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_drop_nulls(constructor: Any, request: Any) -> None: +def test_drop_nulls(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) data = { diff --git a/tests/expr_and_series/dt/datetime_attributes_test.py b/tests/expr_and_series/dt/datetime_attributes_test.py index 22e20590e..5b9519f57 100644 --- a/tests/expr_and_series/dt/datetime_attributes_test.py +++ b/tests/expr_and_series/dt/datetime_attributes_test.py @@ -7,6 +7,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -34,7 +35,10 @@ ], ) def test_datetime_attributes( - request: Any, constructor: Any, attribute: str, expected: list[int] + request: pytest.FixtureRequest, + constructor: Constructor, + attribute: str, + expected: list[int], ) -> None: if ( attribute == "date" @@ -67,7 +71,10 @@ def test_datetime_attributes( ], ) def test_datetime_attributes_series( - request: Any, constructor_eager: Any, attribute: str, expected: list[int] + request: pytest.FixtureRequest, + constructor_eager: Any, + attribute: str, + expected: list[int], ) -> None: if ( attribute == "date" @@ -83,7 +90,9 @@ def test_datetime_attributes_series( 
compare_dicts(result, {"a": expected}) -def test_datetime_chained_attributes(request: Any, constructor_eager: Any) -> None: +def test_datetime_chained_attributes( + request: pytest.FixtureRequest, constructor_eager: Any +) -> None: if "pandas" in str(constructor_eager) and "pyarrow" not in str(constructor_eager): request.applymarker(pytest.mark.xfail) if "cudf" in str(constructor_eager): diff --git a/tests/expr_and_series/dt/datetime_duration_test.py b/tests/expr_and_series/dt/datetime_duration_test.py index 50d254ba3..da5ff325b 100644 --- a/tests/expr_and_series/dt/datetime_duration_test.py +++ b/tests/expr_and_series/dt/datetime_duration_test.py @@ -11,6 +11,7 @@ import narwhals.stable.v1 as nw from narwhals.utils import parse_version +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -37,8 +38,8 @@ ], ) def test_duration_attributes( - request: Any, - constructor: Any, + request: pytest.FixtureRequest, + constructor: Constructor, attribute: str, expected_a: list[int], expected_b: list[int], @@ -46,6 +47,8 @@ def test_duration_attributes( ) -> None: if parse_version(pd.__version__) < (2, 2) and "pandas_pyarrow" in str(constructor): request.applymarker(pytest.mark.xfail) + if "cudf" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -70,7 +73,7 @@ def test_duration_attributes( ], ) def test_duration_attributes_series( - request: Any, + request: pytest.FixtureRequest, constructor_eager: Any, attribute: str, expected_a: list[int], @@ -81,6 +84,8 @@ def test_duration_attributes_series( constructor_eager ): request.applymarker(pytest.mark.xfail) + if "cudf" in str(constructor_eager): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor_eager(data), eager_only=True) diff --git a/tests/expr_and_series/dt/to_string_test.py b/tests/expr_and_series/dt/to_string_test.py index 7cbbf72f2..6017c33d2 100644 --- a/tests/expr_and_series/dt/to_string_test.py +++ 
b/tests/expr_and_series/dt/to_string_test.py @@ -6,6 +6,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts from tests.utils import is_windows @@ -57,7 +58,7 @@ def test_dt_to_string_series(constructor_eager: Any, fmt: str) -> None: ], ) @pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows") -def test_dt_to_string_expr(constructor: Any, fmt: str) -> None: +def test_dt_to_string_expr(constructor: Constructor, fmt: str) -> None: input_frame = nw.from_native(constructor(data)) expected_col = [datetime.strftime(d, fmt) for d in data["a"]] @@ -130,10 +131,8 @@ def test_dt_to_string_iso_local_datetime_series( ) @pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows") def test_dt_to_string_iso_local_datetime_expr( - request: Any, constructor: Any, data: datetime, expected: str + constructor: Constructor, data: datetime, expected: str ) -> None: - if "modin" in str(constructor): - request.applymarker(pytest.mark.xfail) df = constructor({"a": [data]}) result = nw.from_native(df).with_columns( @@ -166,11 +165,8 @@ def test_dt_to_string_iso_local_date_series( ) @pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows") def test_dt_to_string_iso_local_date_expr( - request: Any, constructor: Any, data: datetime, expected: str + constructor: Constructor, data: datetime, expected: str ) -> None: - if "modin" in str(constructor): - request.applymarker(pytest.mark.xfail) - df = constructor({"a": [data]}) result = nw.from_native(df).with_columns( nw.col("a").dt.to_string("%Y-%m-%d").alias("b") diff --git a/tests/expr_and_series/fill_null_test.py b/tests/expr_and_series/fill_null_test.py index 04d6d076f..6efde5ac0 100644 --- a/tests/expr_and_series/fill_null_test.py +++ b/tests/expr_and_series/fill_null_test.py @@ -1,6 +1,7 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ 
-10,7 +11,7 @@ } -def test_fill_null(constructor: Any) -> None: +def test_fill_null(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.with_columns(nw.col("a", "b", "c").fill_null(99)) diff --git a/tests/expr_and_series/filter_test.py b/tests/expr_and_series/filter_test.py index b55a0368e..80267d1d0 100644 --- a/tests/expr_and_series/filter_test.py +++ b/tests/expr_and_series/filter_test.py @@ -3,6 +3,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -13,7 +14,7 @@ } -def test_filter(constructor: Any, request: Any) -> None: +def test_filter(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) diff --git a/tests/expr_and_series/gather_every_test.py b/tests/expr_and_series/gather_every_test.py index b00014f20..e01294ef9 100644 --- a/tests/expr_and_series/gather_every_test.py +++ b/tests/expr_and_series/gather_every_test.py @@ -3,6 +3,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": list(range(10))} @@ -10,7 +11,9 @@ @pytest.mark.parametrize("n", [1, 2, 3]) @pytest.mark.parametrize("offset", [1, 2, 3]) -def test_gather_every_expr(constructor: Any, n: int, offset: int, request: Any) -> None: +def test_gather_every_expr( + constructor: Constructor, n: int, offset: int, request: pytest.FixtureRequest +) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) diff --git a/tests/expr_and_series/head_test.py b/tests/expr_and_series/head_test.py index ef2ed1bf1..2a6326921 100644 --- a/tests/expr_and_series/head_test.py +++ b/tests/expr_and_series/head_test.py @@ -5,11 +5,12 @@ import pytest import narwhals as nw +from tests.utils import Constructor from tests.utils import 
compare_dicts @pytest.mark.parametrize("n", [2, -1]) -def test_head(constructor: Any, n: int, request: Any) -> None: +def test_head(constructor: Constructor, n: int, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and n < 0: diff --git a/tests/expr_and_series/is_between_test.py b/tests/expr_and_series/is_between_test.py index 10c61e9e1..0a9e578ea 100644 --- a/tests/expr_and_series/is_between_test.py +++ b/tests/expr_and_series/is_between_test.py @@ -5,6 +5,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -21,7 +22,7 @@ ("none", [False, True, True, False]), ], ) -def test_is_between(constructor: Any, closed: str, expected: list[bool]) -> None: +def test_is_between(constructor: Constructor, closed: str, expected: list[bool]) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a").is_between(1, 5, closed=closed)) expected_dict = {"a": expected} diff --git a/tests/expr_and_series/is_duplicated_test.py b/tests/expr_and_series/is_duplicated_test.py index 71d165749..7859aed02 100644 --- a/tests/expr_and_series/is_duplicated_test.py +++ b/tests/expr_and_series/is_duplicated_test.py @@ -1,12 +1,13 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": [1, 1, 2], "b": [1, 2, 3], "index": [0, 1, 2]} -def test_is_duplicated_expr(constructor: Any) -> None: +def test_is_duplicated_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a", "b").is_duplicated(), "index").sort("index") expected = {"a": [True, True, False], "b": [False, False, False], "index": [0, 1, 2]} diff --git a/tests/expr_and_series/is_first_distinct_test.py b/tests/expr_and_series/is_first_distinct_test.py index 8521661d6..93ffc5d37 100644 --- 
a/tests/expr_and_series/is_first_distinct_test.py +++ b/tests/expr_and_series/is_first_distinct_test.py @@ -1,6 +1,7 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -9,7 +10,7 @@ } -def test_is_first_distinct_expr(constructor: Any) -> None: +def test_is_first_distinct_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.all().is_first_distinct()) expected = { diff --git a/tests/expr_and_series/is_in_test.py b/tests/expr_and_series/is_in_test.py index 40c7b2718..085b1efbe 100644 --- a/tests/expr_and_series/is_in_test.py +++ b/tests/expr_and_series/is_in_test.py @@ -3,12 +3,13 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": [1, 4, 2, 5]} -def test_expr_is_in(constructor: Any) -> None: +def test_expr_is_in(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a").is_in([4, 5])) expected = {"a": [False, True, False, True]} @@ -24,7 +25,7 @@ def test_ser_is_in(constructor_eager: Any) -> None: compare_dicts(result, expected) -def test_is_in_other(constructor: Any) -> None: +def test_is_in_other(constructor: Constructor) -> None: df_raw = constructor(data) with pytest.raises( NotImplementedError, diff --git a/tests/expr_and_series/is_last_distinct_test.py b/tests/expr_and_series/is_last_distinct_test.py index 2e4709efd..00db7f735 100644 --- a/tests/expr_and_series/is_last_distinct_test.py +++ b/tests/expr_and_series/is_last_distinct_test.py @@ -1,6 +1,7 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -9,7 +10,7 @@ } -def test_is_last_distinct_expr(constructor: Any) -> None: +def test_is_last_distinct_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = 
df.select(nw.all().is_last_distinct()) expected = { diff --git a/tests/expr_and_series/is_null_test.py b/tests/expr_and_series/is_null_test.py index 07465fd9b..85ba55dc4 100644 --- a/tests/expr_and_series/is_null_test.py +++ b/tests/expr_and_series/is_null_test.py @@ -1,10 +1,11 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_null(constructor: Any) -> None: +def test_null(constructor: Constructor) -> None: data_na = {"a": [None, 3, 2], "z": [7.0, None, None]} expected = {"a": [True, False, False], "z": [True, False, False]} df = nw.from_native(constructor(data_na)) diff --git a/tests/expr_and_series/is_unique_test.py b/tests/expr_and_series/is_unique_test.py index d203c1635..b10f7a68f 100644 --- a/tests/expr_and_series/is_unique_test.py +++ b/tests/expr_and_series/is_unique_test.py @@ -1,6 +1,7 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -10,7 +11,7 @@ } -def test_is_unique_expr(constructor: Any) -> None: +def test_is_unique_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a", "b").is_unique(), "index").sort("index") expected = { diff --git a/tests/expr_and_series/len_test.py b/tests/expr_and_series/len_test.py index 8a52dd327..b1e1674bf 100644 --- a/tests/expr_and_series/len_test.py +++ b/tests/expr_and_series/len_test.py @@ -3,10 +3,11 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_len_no_filter(constructor: Any) -> None: +def test_len_no_filter(constructor: Constructor) -> None: data = {"a": list("xyz"), "b": [1, 2, 1]} expected = {"l": [3], "l2": [6]} df = nw.from_native(constructor(data)).select( @@ -17,7 +18,7 @@ def test_len_no_filter(constructor: Any) -> None: compare_dicts(df, expected) -def test_len_chaining(constructor: 
Any, request: Any) -> None: +def test_len_chaining(constructor: Constructor, request: pytest.FixtureRequest) -> None: data = {"a": list("xyz"), "b": [1, 2, 1]} expected = {"a1": [2], "a2": [1]} if "dask" in str(constructor): @@ -30,7 +31,7 @@ def test_len_chaining(constructor: Any, request: Any) -> None: compare_dicts(df, expected) -def test_namespace_len(constructor: Any) -> None: +def test_namespace_len(constructor: Constructor) -> None: df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})).select( nw.len(), a=nw.len() ) diff --git a/tests/expr_and_series/max_test.py b/tests/expr_and_series/max_test.py index 83f24dcfe..1ea32531e 100644 --- a/tests/expr_and_series/max_test.py +++ b/tests/expr_and_series/max_test.py @@ -5,13 +5,14 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} @pytest.mark.parametrize("expr", [nw.col("a", "b", "z").max(), nw.max("a", "b", "z")]) -def test_expr_max_expr(constructor: Any, expr: nw.Expr) -> None: +def test_expr_max_expr(constructor: Constructor, expr: nw.Expr) -> None: df = nw.from_native(constructor(data)) result = df.select(expr) expected = {"a": [3], "b": [6], "z": [9.0]} diff --git a/tests/expr_and_series/mean_horizontal_test.py b/tests/expr_and_series/mean_horizontal_test.py index d42d5e324..f4ad35b92 100644 --- a/tests/expr_and_series/mean_horizontal_test.py +++ b/tests/expr_and_series/mean_horizontal_test.py @@ -3,11 +3,12 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) -def test_meanh(constructor: Any, col_expr: Any) -> None: +def test_meanh(constructor: Constructor, col_expr: Any) -> None: data = {"a": [1, 3, None, None], "b": [4, None, 6, None]} df = nw.from_native(constructor(data)) result = df.select(horizontal_mean=nw.mean_horizontal(col_expr, 
nw.col("b"))) diff --git a/tests/expr_and_series/mean_test.py b/tests/expr_and_series/mean_test.py index f648c9b32..50e6fd862 100644 --- a/tests/expr_and_series/mean_test.py +++ b/tests/expr_and_series/mean_test.py @@ -5,13 +5,14 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": [1, 3, 2], "b": [4, 4, 7], "z": [7.0, 8, 9]} @pytest.mark.parametrize("expr", [nw.col("a", "b", "z").mean(), nw.mean("a", "b", "z")]) -def test_expr_mean_expr(constructor: Any, expr: nw.Expr) -> None: +def test_expr_mean_expr(constructor: Constructor, expr: nw.Expr) -> None: df = nw.from_native(constructor(data)) result = df.select(expr) expected = {"a": [2.0], "b": [5.0], "z": [8.0]} diff --git a/tests/expr_and_series/min_test.py b/tests/expr_and_series/min_test.py index 460e5646b..f6e98e416 100644 --- a/tests/expr_and_series/min_test.py +++ b/tests/expr_and_series/min_test.py @@ -5,13 +5,14 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} @pytest.mark.parametrize("expr", [nw.col("a", "b", "z").min(), nw.min("a", "b", "z")]) -def test_expr_min_expr(constructor: Any, expr: nw.Expr) -> None: +def test_expr_min_expr(constructor: Constructor, expr: nw.Expr) -> None: df = nw.from_native(constructor(data)) result = df.select(expr) expected = {"a": [1], "b": [4], "z": [7.0]} diff --git a/tests/expr_and_series/mode_test.py b/tests/expr_and_series/mode_test.py index 33a0bef5a..8e39405af 100644 --- a/tests/expr_and_series/mode_test.py +++ b/tests/expr_and_series/mode_test.py @@ -1,8 +1,11 @@ from typing import Any +import polars as pl import pytest import narwhals.stable.v1 as nw +from narwhals.utils import parse_version +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -11,7 +14,9 @@ } -def test_mode_single_expr(constructor: Any, request: Any) -> 
None: +def test_mode_single_expr( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) @@ -21,8 +26,12 @@ def test_mode_single_expr(constructor: Any, request: Any) -> None: compare_dicts(result, expected) -def test_mode_multi_expr(constructor: Any, request: Any) -> None: - if "dask" in str(constructor): +def test_mode_multi_expr( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "dask" in str(constructor) or ( + "polars" in str(constructor) and parse_version(pl.__version__) >= (1, 7, 0) + ): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.col("a", "b").mode()).sort("a", "b") diff --git a/tests/expr_and_series/n_unique_test.py b/tests/expr_and_series/n_unique_test.py index f11be2b1c..3790bb1f3 100644 --- a/tests/expr_and_series/n_unique_test.py +++ b/tests/expr_and_series/n_unique_test.py @@ -1,6 +1,7 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -9,7 +10,7 @@ } -def test_n_unique(constructor: Any) -> None: +def test_n_unique(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.all().n_unique()) expected = {"a": [3], "b": [4]} diff --git a/tests/expr_and_series/name/keep_test.py b/tests/expr_and_series/name/keep_test.py index 0b43abe40..be112d716 100644 --- a/tests/expr_and_series/name/keep_test.py +++ b/tests/expr_and_series/name/keep_test.py @@ -1,32 +1,32 @@ from __future__ import annotations from contextlib import nullcontext as does_not_raise -from typing import Any import polars as pl import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]} -def test_keep(constructor: Any) -> None: +def test_keep(constructor: Constructor) -> None: df = 
nw.from_native(constructor(data)) result = df.select((nw.col("foo", "BAR") * 2).name.keep()) expected = {k: [e * 2 for e in v] for k, v in data.items()} compare_dicts(result, expected) -def test_keep_after_alias(constructor: Any) -> None: +def test_keep_after_alias(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo")).alias("alias_for_foo").name.keep()) expected = {"foo": data["foo"]} compare_dicts(result, expected) -def test_keep_raise_anonymous(constructor: Any) -> None: +def test_keep_raise_anonymous(constructor: Constructor) -> None: df_raw = constructor(data) df = nw.from_native(df_raw) diff --git a/tests/expr_and_series/name/map_test.py b/tests/expr_and_series/name/map_test.py index ff039e30d..5fad9f930 100644 --- a/tests/expr_and_series/name/map_test.py +++ b/tests/expr_and_series/name/map_test.py @@ -1,12 +1,12 @@ from __future__ import annotations from contextlib import nullcontext as does_not_raise -from typing import Any import polars as pl import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]} @@ -16,21 +16,21 @@ def map_func(s: str | None) -> str: return str(s)[::-1].lower() -def test_map(constructor: Any) -> None: +def test_map(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo", "BAR") * 2).name.map(function=map_func)) expected = {map_func(k): [e * 2 for e in v] for k, v in data.items()} compare_dicts(result, expected) -def test_map_after_alias(constructor: Any) -> None: +def test_map_after_alias(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo")).alias("alias_for_foo").name.map(function=map_func)) expected = {map_func("foo"): data["foo"]} compare_dicts(result, expected) -def test_map_raise_anonymous(constructor: Any) -> None: +def test_map_raise_anonymous(constructor: 
Constructor) -> None: df_raw = constructor(data) df = nw.from_native(df_raw) diff --git a/tests/expr_and_series/name/prefix_test.py b/tests/expr_and_series/name/prefix_test.py index f538d4136..95d72914f 100644 --- a/tests/expr_and_series/name/prefix_test.py +++ b/tests/expr_and_series/name/prefix_test.py @@ -1,33 +1,33 @@ from __future__ import annotations from contextlib import nullcontext as does_not_raise -from typing import Any import polars as pl import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]} prefix = "with_prefix_" -def test_prefix(constructor: Any) -> None: +def test_prefix(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo", "BAR") * 2).name.prefix(prefix)) expected = {prefix + str(k): [e * 2 for e in v] for k, v in data.items()} compare_dicts(result, expected) -def test_suffix_after_alias(constructor: Any) -> None: +def test_suffix_after_alias(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo")).alias("alias_for_foo").name.prefix(prefix)) expected = {prefix + "foo": data["foo"]} compare_dicts(result, expected) -def test_prefix_raise_anonymous(constructor: Any) -> None: +def test_prefix_raise_anonymous(constructor: Constructor) -> None: df_raw = constructor(data) df = nw.from_native(df_raw) diff --git a/tests/expr_and_series/name/suffix_test.py b/tests/expr_and_series/name/suffix_test.py index 0e952449b..1802f26f6 100644 --- a/tests/expr_and_series/name/suffix_test.py +++ b/tests/expr_and_series/name/suffix_test.py @@ -1,33 +1,33 @@ from __future__ import annotations from contextlib import nullcontext as does_not_raise -from typing import Any import polars as pl import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]} suffix = 
"_with_suffix" -def test_suffix(constructor: Any) -> None: +def test_suffix(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo", "BAR") * 2).name.suffix(suffix)) expected = {str(k) + suffix: [e * 2 for e in v] for k, v in data.items()} compare_dicts(result, expected) -def test_suffix_after_alias(constructor: Any) -> None: +def test_suffix_after_alias(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo")).alias("alias_for_foo").name.suffix(suffix)) expected = {"foo" + suffix: data["foo"]} compare_dicts(result, expected) -def test_suffix_raise_anonymous(constructor: Any) -> None: +def test_suffix_raise_anonymous(constructor: Constructor) -> None: df_raw = constructor(data) df = nw.from_native(df_raw) diff --git a/tests/expr_and_series/name/to_lowercase_test.py b/tests/expr_and_series/name/to_lowercase_test.py index a9e8bfcfd..fedac9cd3 100644 --- a/tests/expr_and_series/name/to_lowercase_test.py +++ b/tests/expr_and_series/name/to_lowercase_test.py @@ -1,32 +1,32 @@ from __future__ import annotations from contextlib import nullcontext as does_not_raise -from typing import Any import polars as pl import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]} -def test_to_lowercase(constructor: Any) -> None: +def test_to_lowercase(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo", "BAR") * 2).name.to_lowercase()) expected = {k.lower(): [e * 2 for e in v] for k, v in data.items()} compare_dicts(result, expected) -def test_to_lowercase_after_alias(constructor: Any) -> None: +def test_to_lowercase_after_alias(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("BAR")).alias("ALIAS_FOR_BAR").name.to_lowercase()) expected = {"bar": data["BAR"]} 
compare_dicts(result, expected) -def test_to_lowercase_raise_anonymous(constructor: Any) -> None: +def test_to_lowercase_raise_anonymous(constructor: Constructor) -> None: df_raw = constructor(data) df = nw.from_native(df_raw) diff --git a/tests/expr_and_series/name/to_uppercase_test.py b/tests/expr_and_series/name/to_uppercase_test.py index 035dfeff2..29b70bd99 100644 --- a/tests/expr_and_series/name/to_uppercase_test.py +++ b/tests/expr_and_series/name/to_uppercase_test.py @@ -1,32 +1,32 @@ from __future__ import annotations from contextlib import nullcontext as does_not_raise -from typing import Any import polars as pl import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]} -def test_to_uppercase(constructor: Any) -> None: +def test_to_uppercase(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo", "BAR") * 2).name.to_uppercase()) expected = {k.upper(): [e * 2 for e in v] for k, v in data.items()} compare_dicts(result, expected) -def test_to_uppercase_after_alias(constructor: Any) -> None: +def test_to_uppercase_after_alias(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo")).alias("alias_for_foo").name.to_uppercase()) expected = {"FOO": data["foo"]} compare_dicts(result, expected) -def test_to_uppercase_raise_anonymous(constructor: Any) -> None: +def test_to_uppercase_raise_anonymous(constructor: Constructor) -> None: df_raw = constructor(data) df = nw.from_native(df_raw) diff --git a/tests/expr_and_series/null_count_test.py b/tests/expr_and_series/null_count_test.py index a6cb58f71..6be15ab32 100644 --- a/tests/expr_and_series/null_count_test.py +++ b/tests/expr_and_series/null_count_test.py @@ -1,6 +1,7 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { 
@@ -9,7 +10,7 @@ } -def test_null_count_expr(constructor: Any) -> None: +def test_null_count_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.all().null_count()) expected = { diff --git a/tests/expr_and_series/operators_test.py b/tests/expr_and_series/operators_test.py index 113824a94..e3f39465c 100644 --- a/tests/expr_and_series/operators_test.py +++ b/tests/expr_and_series/operators_test.py @@ -5,6 +5,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts @@ -20,7 +21,7 @@ ], ) def test_comparand_operators_scalar_expr( - constructor: Any, operator: str, expected: list[bool] + constructor: Constructor, operator: str, expected: list[bool] ) -> None: data = {"a": [0, 1, 2]} df = nw.from_native(constructor(data)) @@ -40,7 +41,7 @@ def test_comparand_operators_scalar_expr( ], ) def test_comparand_operators_expr( - constructor: Any, operator: str, expected: list[bool] + constructor: Constructor, operator: str, expected: list[bool] ) -> None: data = {"a": [0, 1, 1], "b": [0, 0, 2]} df = nw.from_native(constructor(data)) @@ -56,7 +57,7 @@ def test_comparand_operators_expr( ], ) def test_logic_operators_expr( - constructor: Any, operator: str, expected: list[bool] + constructor: Constructor, operator: str, expected: list[bool] ) -> None: data = {"a": [True, True, False, False], "b": [True, False, True, False]} df = nw.from_native(constructor(data)) diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index 17b07cc1e..2abc9a699 100644 --- a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -1,9 +1,9 @@ from contextlib import nullcontext as does_not_raise -from typing import Any import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -13,7 +13,7 @@ } -def test_over_single(constructor: Any) -> None: +def 
test_over_single(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) expected = { "a": ["a", "a", "b", "b", "b"], @@ -36,7 +36,7 @@ def test_over_single(constructor: Any) -> None: compare_dicts(result, expected) -def test_over_multiple(constructor: Any) -> None: +def test_over_multiple(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) expected = { "a": ["a", "a", "b", "b", "b"], @@ -59,7 +59,7 @@ def test_over_multiple(constructor: Any) -> None: compare_dicts(result, expected) -def test_over_invalid(request: Any, constructor: Any) -> None: +def test_over_invalid(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "polars" in str(constructor): request.applymarker(pytest.mark.xfail) diff --git a/tests/expr_and_series/pipe_test.py b/tests/expr_and_series/pipe_test.py index 55de3548b..2134a931b 100644 --- a/tests/expr_and_series/pipe_test.py +++ b/tests/expr_and_series/pipe_test.py @@ -1,13 +1,14 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts input_list = {"a": [2, 4, 6, 8]} expected = [4, 16, 36, 64] -def test_pipe_expr(constructor: Any) -> None: +def test_pipe_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(input_list)) e = df.select(nw.col("a").pipe(lambda x: x**2)) compare_dicts(e, {"a": expected}) diff --git a/tests/expr_and_series/quantile_test.py b/tests/expr_and_series/quantile_test.py index 5b8ff9334..aae2b3647 100644 --- a/tests/expr_and_series/quantile_test.py +++ b/tests/expr_and_series/quantile_test.py @@ -7,6 +7,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts @@ -22,10 +23,10 @@ ) @pytest.mark.filterwarnings("ignore:the `interpolation=` argument to percentile") def test_quantile_expr( - constructor: Any, + constructor: Constructor, interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], 
expected: dict[str, list[float]], - request: Any, + request: pytest.FixtureRequest, ) -> None: if "dask" in str(constructor) and interpolation != "linear": request.applymarker(pytest.mark.xfail) diff --git a/tests/expr_and_series/reduction_test.py b/tests/expr_and_series/reduction_test.py index 60750444e..e22080e62 100644 --- a/tests/expr_and_series/reduction_test.py +++ b/tests/expr_and_series/reduction_test.py @@ -5,6 +5,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts @@ -26,7 +27,7 @@ ids=range(5), ) def test_scalar_reduction_select( - constructor: Any, expr: list[Any], expected: dict[str, list[Any]] + constructor: Constructor, expr: list[Any], expected: dict[str, list[Any]] ) -> None: data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = nw.from_native(constructor(data)) @@ -52,7 +53,7 @@ def test_scalar_reduction_select( ids=range(5), ) def test_scalar_reduction_with_columns( - constructor: Any, expr: list[Any], expected: dict[str, list[Any]] + constructor: Constructor, expr: list[Any], expected: dict[str, list[Any]] ) -> None: data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = nw.from_native(constructor(data)) diff --git a/tests/expr_and_series/round_test.py b/tests/expr_and_series/round_test.py index 769e4be11..37d6ce131 100644 --- a/tests/expr_and_series/round_test.py +++ b/tests/expr_and_series/round_test.py @@ -5,11 +5,12 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts @pytest.mark.parametrize("decimals", [0, 1, 2]) -def test_round(constructor: Any, decimals: int) -> None: +def test_round(constructor: Constructor, decimals: int) -> None: data = {"a": [2.12345, 2.56789, 3.901234]} df_raw = constructor(data) df = nw.from_native(df_raw) diff --git a/tests/expr_and_series/sample_test.py b/tests/expr_and_series/sample_test.py index c64703d3c..eb6d853ec 100644 --- a/tests/expr_and_series/sample_test.py +++ 
b/tests/expr_and_series/sample_test.py @@ -1,11 +1,11 @@ -from typing import Any - import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor +from tests.utils import compare_dicts -def test_expr_sample(constructor: Any, request: Any) -> None: +def test_expr_sample(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})).lazy() @@ -19,7 +19,9 @@ def test_expr_sample(constructor: Any, request: Any) -> None: assert result_series == expected_series -def test_expr_sample_fraction(constructor: Any, request: Any) -> None: +def test_expr_sample_fraction( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 2, 3] * 10, "b": [4, 5, 6] * 10})).lazy() @@ -31,3 +33,37 @@ def test_expr_sample_fraction(constructor: Any, request: Any) -> None: result_series = df.collect()["a"].sample(fraction=0.1).shape expected_series = (3,) assert result_series == expected_series + + +def test_sample_with_seed( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + + size, n = 100, 10 + df = nw.from_native(constructor({"a": list(range(size))})).lazy() + expected = {"res1": [True], "res2": [False]} + result = ( + df.select( + seed1=nw.col("a").sample(n=n, seed=123), + seed2=nw.col("a").sample(n=n, seed=123), + seed3=nw.col("a").sample(n=n, seed=42), + ) + .select( + res1=(nw.col("seed1") == nw.col("seed2")).all(), + res2=(nw.col("seed1") == nw.col("seed3")).all(), + ) + .collect() + ) + + compare_dicts(result, expected) + + series = df.collect()["a"] + seed1 = series.sample(n=n, seed=123) + seed2 = series.sample(n=n, seed=123) + seed3 = series.sample(n=n, seed=42) + + compare_dicts( + {"res1": [(seed1 == 
seed2).all()], "res2": [(seed1 == seed3).all()]}, expected + ) diff --git a/tests/expr_and_series/shift_test.py b/tests/expr_and_series/shift_test.py index 02dbed6b0..b165adf12 100644 --- a/tests/expr_and_series/shift_test.py +++ b/tests/expr_and_series/shift_test.py @@ -3,6 +3,7 @@ import pyarrow as pa import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -13,7 +14,7 @@ } -def test_shift(constructor: Any) -> None: +def test_shift(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.with_columns(nw.col("a", "b", "c").shift(2)).filter(nw.col("i") > 1) expected = { diff --git a/tests/expr_and_series/std_test.py b/tests/expr_and_series/std_test.py index 0c7e2f338..400a6e0af 100644 --- a/tests/expr_and_series/std_test.py +++ b/tests/expr_and_series/std_test.py @@ -1,12 +1,13 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} -def test_std(constructor: Any) -> None: +def test_std(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select( nw.col("a").std().alias("a_ddof_default"), diff --git a/tests/expr_and_series/str/contains_test.py b/tests/expr_and_series/str/contains_test.py index 5cc90f4ad..312de50a4 100644 --- a/tests/expr_and_series/str/contains_test.py +++ b/tests/expr_and_series/str/contains_test.py @@ -2,8 +2,10 @@ import pandas as pd import polars as pl +import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"pets": ["cat", "dog", "rabbit and parrot", "dove"]} @@ -12,7 +14,10 @@ df_polars = pl.DataFrame(data) -def test_contains(constructor: Any) -> None: +def test_contains(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "cudf" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = 
nw.from_native(constructor(data)) result = df.with_columns( nw.col("pets").str.contains("(?i)parrot|Dove").alias("result") @@ -24,7 +29,10 @@ def test_contains(constructor: Any) -> None: compare_dicts(result, expected) -def test_contains_series(constructor_eager: Any) -> None: +def test_contains_series(constructor_eager: Any, request: pytest.FixtureRequest) -> None: + if "cudf" in str(constructor_eager): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor_eager(data), eager_only=True) result = df.with_columns( case_insensitive_match=df["pets"].str.contains("(?i)parrot|Dove") diff --git a/tests/expr_and_series/str/head_test.py b/tests/expr_and_series/str/head_test.py index 1160920fd..a4b3e7296 100644 --- a/tests/expr_and_series/str/head_test.py +++ b/tests/expr_and_series/str/head_test.py @@ -1,12 +1,13 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": ["foo", "bars"]} -def test_str_head(constructor: Any) -> None: +def test_str_head(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a").str.head(3)) expected = { diff --git a/tests/expr_and_series/str/replace_test.py b/tests/expr_and_series/str/replace_test.py index 95b5bd87c..b0cffb1b4 100644 --- a/tests/expr_and_series/str/replace_test.py +++ b/tests/expr_and_series/str/replace_test.py @@ -5,6 +5,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts replace_data = [ @@ -92,7 +93,7 @@ def test_str_replace_all_series( replace_data, ) def test_str_replace_expr( - constructor: Any, + constructor: Constructor, data: dict[str, list[str]], pattern: str, value: str, @@ -113,7 +114,7 @@ def test_str_replace_expr( replace_all_data, ) def test_str_replace_all_expr( - constructor: Any, + constructor: Constructor, data: dict[str, list[str]], pattern: str, value: str, diff --git 
a/tests/expr_and_series/str/slice_test.py b/tests/expr_and_series/str/slice_test.py index e4e7905f2..e7fe0efa1 100644 --- a/tests/expr_and_series/str/slice_test.py +++ b/tests/expr_and_series/str/slice_test.py @@ -5,6 +5,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": ["fdas", "edfas"]} @@ -15,7 +16,7 @@ [(1, 2, {"a": ["da", "df"]}), (-2, None, {"a": ["as", "as"]})], ) def test_str_slice( - constructor: Any, offset: int, length: int | None, expected: Any + constructor: Constructor, offset: int, length: int | None, expected: Any ) -> None: df = nw.from_native(constructor(data)) result_frame = df.select(nw.col("a").str.slice(offset, length)) diff --git a/tests/expr_and_series/str/starts_with_ends_with_test.py b/tests/expr_and_series/str/starts_with_ends_with_test.py index a5101edcb..e8b0afaa9 100644 --- a/tests/expr_and_series/str/starts_with_ends_with_test.py +++ b/tests/expr_and_series/str/starts_with_ends_with_test.py @@ -3,6 +3,7 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor # Don't move this into typechecking block, for coverage # purposes @@ -11,7 +12,7 @@ data = {"a": ["fdas", "edfas"]} -def test_ends_with(constructor: Any) -> None: +def test_ends_with(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a").str.ends_with("das")) expected = { @@ -29,7 +30,7 @@ def test_ends_with_series(constructor_eager: Any) -> None: compare_dicts(result, expected) -def test_starts_with(constructor: Any) -> None: +def test_starts_with(constructor: Constructor) -> None: df = nw.from_native(constructor(data)).lazy() result = df.select(nw.col("a").str.starts_with("fda")) expected = { diff --git a/tests/expr_and_series/str/strip_chars_test.py b/tests/expr_and_series/str/strip_chars_test.py index f6cbcc4fa..3d5b74456 100644 --- a/tests/expr_and_series/str/strip_chars_test.py +++ 
b/tests/expr_and_series/str/strip_chars_test.py @@ -5,6 +5,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": ["foobar", "bar\n", " baz"]} @@ -17,7 +18,9 @@ ("foo", {"a": ["bar", "bar\n", " baz"]}), ], ) -def test_str_strip_chars(constructor: Any, characters: str | None, expected: Any) -> None: +def test_str_strip_chars( + constructor: Constructor, characters: str | None, expected: Any +) -> None: df = nw.from_native(constructor(data)) result_frame = df.select(nw.col("a").str.strip_chars(characters)) compare_dicts(result_frame, expected) diff --git a/tests/expr_and_series/str/tail_test.py b/tests/expr_and_series/str/tail_test.py index c863cca0e..92d474262 100644 --- a/tests/expr_and_series/str/tail_test.py +++ b/tests/expr_and_series/str/tail_test.py @@ -1,12 +1,13 @@ from typing import Any import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": ["foo", "bars"]} -def test_str_tail(constructor: Any) -> None: +def test_str_tail(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) expected = {"a": ["foo", "ars"]} diff --git a/tests/expr_and_series/str/to_datetime_test.py b/tests/expr_and_series/str/to_datetime_test.py index 8c3d1a51a..ad666aa8a 100644 --- a/tests/expr_and_series/str/to_datetime_test.py +++ b/tests/expr_and_series/str/to_datetime_test.py @@ -1,11 +1,10 @@ -from typing import Any - import narwhals.stable.v1 as nw +from tests.utils import Constructor data = {"a": ["2020-01-01T12:34:56"]} -def test_to_datetime(constructor: Any) -> None: +def test_to_datetime(constructor: Constructor) -> None: result = ( nw.from_native(constructor(data)) .lazy() diff --git a/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py b/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py index 4d2f2f745..877409138 100644 --- a/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py +++ 
b/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py @@ -7,6 +7,7 @@ import narwhals.stable.v1 as nw from narwhals.utils import parse_version +from tests.utils import Constructor from tests.utils import compare_dicts @@ -26,10 +27,10 @@ ], ) def test_str_to_uppercase( - constructor: Any, + constructor: Constructor, data: dict[str, list[str]], expected: dict[str, list[str]], - request: Any, + request: pytest.FixtureRequest, ) -> None: df = nw.from_native(constructor(data)) result_frame = df.select(nw.col("a").str.to_uppercase()) @@ -70,7 +71,7 @@ def test_str_to_uppercase_series( constructor_eager: Any, data: dict[str, list[str]], expected: dict[str, list[str]], - request: Any, + request: pytest.FixtureRequest, ) -> None: df = nw.from_native(constructor_eager(data), eager_only=True) @@ -80,6 +81,7 @@ def test_str_to_uppercase_series( "pandas_constructor", "pandas_nullable_constructor", "polars_eager_constructor", + "cudf_constructor", ) ): # We are marking it xfail for these conditions above @@ -107,7 +109,7 @@ def test_str_to_uppercase_series( ], ) def test_str_to_lowercase( - constructor: Any, + constructor: Constructor, data: dict[str, list[str]], expected: dict[str, list[str]], ) -> None: diff --git a/tests/expr_and_series/sum_horizontal_test.py b/tests/expr_and_series/sum_horizontal_test.py index 4c4ab924c..bfaab7238 100644 --- a/tests/expr_and_series/sum_horizontal_test.py +++ b/tests/expr_and_series/sum_horizontal_test.py @@ -3,11 +3,12 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) -def test_sumh(constructor: Any, col_expr: Any) -> None: +def test_sumh(constructor: Constructor, col_expr: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.with_columns(horizontal_sum=nw.sum_horizontal(col_expr, nw.col("b"))) @@ -20,7 +21,7 @@ def 
test_sumh(constructor: Any, col_expr: Any) -> None: compare_dicts(result, expected) -def test_sumh_nullable(constructor: Any) -> None: +def test_sumh_nullable(constructor: Constructor) -> None: data = {"a": [1, 8, 3], "b": [4, 5, None]} expected = {"hsum": [5, 13, 3]} diff --git a/tests/expr_and_series/sum_test.py b/tests/expr_and_series/sum_test.py index c61a9ed79..8059a097d 100644 --- a/tests/expr_and_series/sum_test.py +++ b/tests/expr_and_series/sum_test.py @@ -5,13 +5,14 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} @pytest.mark.parametrize("expr", [nw.col("a", "b", "z").sum(), nw.sum("a", "b", "z")]) -def test_expr_sum_expr(constructor: Any, expr: nw.Expr) -> None: +def test_expr_sum_expr(constructor: Constructor, expr: nw.Expr) -> None: df = nw.from_native(constructor(data)) result = df.select(expr) expected = {"a": [6], "b": [14], "z": [24.0]} diff --git a/tests/expr_and_series/tail_test.py b/tests/expr_and_series/tail_test.py index be17ffb4e..fc3e6159a 100644 --- a/tests/expr_and_series/tail_test.py +++ b/tests/expr_and_series/tail_test.py @@ -1,15 +1,14 @@ -from __future__ import annotations - from typing import Any import pytest import narwhals as nw +from tests.utils import Constructor from tests.utils import compare_dicts @pytest.mark.parametrize("n", [2, -1]) -def test_head(constructor: Any, n: int, request: Any) -> None: +def test_head(constructor: Constructor, n: int, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and n < 0: diff --git a/tests/expr_and_series/unary_test.py b/tests/expr_and_series/unary_test.py index 7df0099dd..dabab0c03 100644 --- a/tests/expr_and_series/unary_test.py +++ b/tests/expr_and_series/unary_test.py @@ -3,10 +3,11 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import 
Constructor from tests.utils import compare_dicts -def test_unary(constructor: Any, request: Any) -> None: +def test_unary(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} diff --git a/tests/expr_and_series/unique_test.py b/tests/expr_and_series/unique_test.py index 488d793cd..5639179ba 100644 --- a/tests/expr_and_series/unique_test.py +++ b/tests/expr_and_series/unique_test.py @@ -3,12 +3,13 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": [1, 1, 2]} -def test_unique_expr(constructor: Any, request: Any) -> None: +def test_unique_expr(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 5b60edfa9..ed4a2ccd6 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -6,6 +6,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -23,7 +24,7 @@ } -def test_when(constructor: Any) -> None: +def test_when(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(value=3).alias("a_when")) expected = { @@ -32,7 +33,7 @@ def test_when(constructor: Any) -> None: compare_dicts(result, expected) -def test_when_otherwise(constructor: Any) -> None: +def test_when_otherwise(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) expected = { @@ -41,7 +42,7 @@ def test_when_otherwise(constructor: Any) -> None: compare_dicts(result, expected) -def 
test_multiple_conditions(constructor: Any) -> None: +def test_multiple_conditions(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select( nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when") @@ -52,13 +53,15 @@ def test_multiple_conditions(constructor: Any) -> None: compare_dicts(result, expected) -def test_no_arg_when_fail(constructor: Any) -> None: +def test_no_arg_when_fail(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) with pytest.raises((TypeError, ValueError)): df.select(nw.when().then(value=3).alias("a_when")) -def test_value_numpy_array(request: Any, constructor: Any) -> None: +def test_value_numpy_array( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) @@ -86,7 +89,7 @@ def test_value_series(constructor_eager: Any) -> None: compare_dicts(result, expected) -def test_value_expression(constructor: Any) -> None: +def test_value_expression(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(nw.col("a") + 9).alias("a_when")) expected = { @@ -95,7 +98,9 @@ def test_value_expression(constructor: Any) -> None: compare_dicts(result, expected) -def test_otherwise_numpy_array(request: Any, constructor: Any) -> None: +def test_otherwise_numpy_array( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) @@ -123,7 +128,9 @@ def test_otherwise_series(constructor_eager: Any) -> None: compare_dicts(result, expected) -def test_otherwise_expression(request: Any, constructor: Any) -> None: +def test_otherwise_expression( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) @@ -137,7 +144,9 @@ def test_otherwise_expression(request: Any, constructor: Any) -> None: 
compare_dicts(result, expected) -def test_when_then_otherwise_into_expr(request: Any, constructor: Any) -> None: +def test_when_then_otherwise_into_expr( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) diff --git a/tests/frame/add_test.py b/tests/frame/add_test.py index b885bb05d..c95fbae97 100644 --- a/tests/frame/add_test.py +++ b/tests/frame/add_test.py @@ -1,10 +1,9 @@ -from typing import Any - import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_add(constructor: Any) -> None: +def test_add(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.with_columns( diff --git a/tests/frame/array_dunder_test.py b/tests/frame/array_dunder_test.py index 8f9cbd16a..8a082bb1f 100644 --- a/tests/frame/array_dunder_test.py +++ b/tests/frame/array_dunder_test.py @@ -11,7 +11,7 @@ from tests.utils import compare_dicts -def test_array_dunder(request: Any, constructor_eager: Any) -> None: +def test_array_dunder(request: pytest.FixtureRequest, constructor_eager: Any) -> None: if "pyarrow_table" in str(constructor_eager) and parse_version( pa.__version__ ) < parse_version("16.0.0"): # pragma: no cover @@ -22,7 +22,9 @@ def test_array_dunder(request: Any, constructor_eager: Any) -> None: np.testing.assert_array_equal(result, np.array([[1], [2], [3]], dtype="int64")) -def test_array_dunder_with_dtype(request: Any, constructor_eager: Any) -> None: +def test_array_dunder_with_dtype( + request: pytest.FixtureRequest, constructor_eager: Any +) -> None: if "pyarrow_table" in str(constructor_eager) and parse_version( pa.__version__ ) < parse_version("16.0.0"): # pragma: no cover @@ -33,7 +35,9 @@ def test_array_dunder_with_dtype(request: Any, constructor_eager: Any) -> None: np.testing.assert_array_equal(result, np.array([[1], [2], [3]], 
dtype=object)) -def test_array_dunder_with_copy(request: Any, constructor_eager: Any) -> None: +def test_array_dunder_with_copy( + request: pytest.FixtureRequest, constructor_eager: Any +) -> None: if "pyarrow_table" in str(constructor_eager) and parse_version(pa.__version__) < ( 16, 0, diff --git a/tests/frame/clone_test.py b/tests/frame/clone_test.py index 6e8b19beb..e94183e2e 100644 --- a/tests/frame/clone_test.py +++ b/tests/frame/clone_test.py @@ -1,12 +1,11 @@ -from typing import Any - import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_clone(request: Any, constructor: Any) -> None: +def test_clone(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) if "pyarrow_table" in str(constructor): diff --git a/tests/frame/columns_test.py b/tests/frame/columns_test.py index 157051ba3..90a9c922d 100644 --- a/tests/frame/columns_test.py +++ b/tests/frame/columns_test.py @@ -1,12 +1,11 @@ -from typing import Any - import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor @pytest.mark.filterwarnings("ignore:Determining|Resolving.*") -def test_columns(constructor: Any) -> None: +def test_columns(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.columns diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index a52759128..926f3f988 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -1,12 +1,11 @@ -from typing import Any - import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_concat_horizontal(constructor: Any) -> None: +def test_concat_horizontal(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_left = 
nw.from_native(constructor(data)).lazy() @@ -27,7 +26,7 @@ def test_concat_horizontal(constructor: Any) -> None: nw.concat([]) -def test_concat_vertical(constructor: Any) -> None: +def test_concat_vertical(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_left = ( nw.from_native(constructor(data)).lazy().rename({"a": "c", "b": "d"}).drop("z") diff --git a/tests/frame/double_test.py b/tests/frame/double_test.py index 5d52d0d26..6840145ec 100644 --- a/tests/frame/double_test.py +++ b/tests/frame/double_test.py @@ -1,10 +1,9 @@ -from typing import Any - import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_double(constructor: Any) -> None: +def test_double(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) diff --git a/tests/frame/drop_nulls_test.py b/tests/frame/drop_nulls_test.py index 58c9486ed..4c2276030 100644 --- a/tests/frame/drop_nulls_test.py +++ b/tests/frame/drop_nulls_test.py @@ -1,10 +1,9 @@ from __future__ import annotations -from typing import Any - import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -13,7 +12,7 @@ } -def test_drop_nulls(constructor: Any) -> None: +def test_drop_nulls(constructor: Constructor) -> None: result = nw.from_native(constructor(data)).drop_nulls() expected = { "a": [2.0, 4.0], @@ -23,7 +22,7 @@ def test_drop_nulls(constructor: Any) -> None: @pytest.mark.parametrize("subset", ["a", ["a"]]) -def test_drop_nulls_subset(constructor: Any, subset: str | list[str]) -> None: +def test_drop_nulls_subset(constructor: Constructor, subset: str | list[str]) -> None: result = nw.from_native(constructor(data)).drop_nulls(subset=subset) expected = { "a": [1, 2.0, 4.0], diff --git a/tests/frame/drop_test.py b/tests/frame/drop_test.py index db039fcb2..f8fc33254 100644 --- 
a/tests/frame/drop_test.py +++ b/tests/frame/drop_test.py @@ -1,6 +1,7 @@ from __future__ import annotations from contextlib import nullcontext as does_not_raise +from typing import TYPE_CHECKING from typing import Any import polars as pl @@ -11,6 +12,9 @@ from narwhals._exceptions import ColumnNotFoundError from narwhals.utils import parse_version +if TYPE_CHECKING: + from tests.utils import Constructor + @pytest.mark.parametrize( ("to_drop", "expected"), @@ -20,7 +24,7 @@ (["abc", "b"], ["z"]), ], ) -def test_drop(constructor: Any, to_drop: list[str], expected: list[str]) -> None: +def test_drop(constructor: Constructor, to_drop: list[str], expected: list[str]) -> None: data = {"abc": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) assert df.drop(to_drop).collect_schema().names() == expected @@ -38,7 +42,13 @@ def test_drop(constructor: Any, to_drop: list[str], expected: list[str]) -> None (False, does_not_raise()), ], ) -def test_drop_strict(request: Any, constructor: Any, strict: bool, context: Any) -> None: # noqa: FBT001 +def test_drop_strict( + request: pytest.FixtureRequest, + constructor: Constructor, + context: Any, + *, + strict: bool, +) -> None: if ( "polars_lazy" in str(request) and parse_version(pl.__version__) < (1, 0, 0) diff --git a/tests/frame/filter_test.py b/tests/frame/filter_test.py index e7a289feb..9c9b1b6fd 100644 --- a/tests/frame/filter_test.py +++ b/tests/frame/filter_test.py @@ -1,13 +1,13 @@ from contextlib import nullcontext as does_not_raise -from typing import Any import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_filter(constructor: Any) -> None: +def test_filter(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.filter(nw.col("a") > 1) @@ -15,7 +15,7 @@ def test_filter(constructor: Any) -> None: compare_dicts(result, 
expected) -def test_filter_with_boolean_list(constructor: Any) -> None: +def test_filter_with_boolean_list(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) diff --git a/tests/frame/gather_every_test.py b/tests/frame/gather_every_test.py index 90b06e3d6..40e18a30b 100644 --- a/tests/frame/gather_every_test.py +++ b/tests/frame/gather_every_test.py @@ -1,8 +1,7 @@ -from typing import Any - import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": list(range(10))} @@ -10,7 +9,7 @@ @pytest.mark.parametrize("n", [1, 2, 3]) @pytest.mark.parametrize("offset", [1, 2, 3]) -def test_gather_every(constructor: Any, n: int, offset: int) -> None: +def test_gather_every(constructor: Constructor, n: int, offset: int) -> None: df = nw.from_native(constructor(data)) result = df.gather_every(n=n, offset=offset) expected = {"a": data["a"][offset::n]} diff --git a/tests/frame/head_test.py b/tests/frame/head_test.py index e4b762f48..7234828b0 100644 --- a/tests/frame/head_test.py +++ b/tests/frame/head_test.py @@ -1,12 +1,11 @@ from __future__ import annotations -from typing import Any - import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_head(constructor: Any) -> None: +def test_head(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]} diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index 18e9aae64..85c76eba7 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -11,10 +11,11 @@ import narwhals.stable.v1 as nw from narwhals.utils import Implementation from narwhals.utils import parse_version +from tests.utils import Constructor from tests.utils import compare_dicts -def test_inner_join_two_keys(constructor: Any) -> None: +def 
test_inner_join_two_keys(constructor: Constructor) -> None: data = { "antananarivo": [1, 3, 2], "bob": [4, 4, 6], @@ -43,7 +44,7 @@ def test_inner_join_two_keys(constructor: Any) -> None: compare_dicts(result_on, expected) -def test_inner_join_single_key(constructor: Any) -> None: +def test_inner_join_single_key(constructor: Constructor) -> None: data = { "antananarivo": [1, 3, 2], "bob": [4, 4, 6], @@ -73,7 +74,7 @@ def test_inner_join_single_key(constructor: Any) -> None: compare_dicts(result_on, expected) -def test_cross_join(constructor: Any) -> None: +def test_cross_join(constructor: Constructor) -> None: data = {"antananarivo": [1, 3, 2]} df = nw.from_native(constructor(data)) result = df.join(df, how="cross").sort("antananarivo", "antananarivo_right") # type: ignore[arg-type] @@ -91,7 +92,7 @@ def test_cross_join(constructor: Any) -> None: @pytest.mark.parametrize("how", ["inner", "left"]) @pytest.mark.parametrize("suffix", ["_right", "_custom_suffix"]) -def test_suffix(constructor: Any, how: str, suffix: str) -> None: +def test_suffix(constructor: Constructor, how: str, suffix: str) -> None: data = { "antananarivo": [1, 3, 2], "bob": [4, 4, 6], @@ -111,7 +112,7 @@ def test_suffix(constructor: Any, how: str, suffix: str) -> None: @pytest.mark.parametrize("suffix", ["_right", "_custom_suffix"]) -def test_cross_join_suffix(constructor: Any, suffix: str) -> None: +def test_cross_join_suffix(constructor: Constructor, suffix: str) -> None: data = {"antananarivo": [1, 3, 2]} df = nw.from_native(constructor(data)) result = df.join(df, how="cross", suffix=suffix).sort( # type: ignore[arg-type] @@ -154,7 +155,7 @@ def test_cross_join_non_pandas() -> None: ], ) def test_anti_join( - constructor: Any, + constructor: Constructor, join_key: list[str], filter_expr: nw.Expr, expected: dict[str, list[Any]], @@ -169,6 +170,11 @@ def test_anti_join( @pytest.mark.parametrize( ("join_key", "filter_expr", "expected"), [ + ( + "antananarivo", + (nw.col("bob") > 5), + 
{"antananarivo": [2], "bob": [6], "zorro": [9]}, + ), ( ["antananarivo"], (nw.col("bob") > 5), @@ -187,7 +193,7 @@ def test_anti_join( ], ) def test_semi_join( - constructor: Any, + constructor: Constructor, join_key: list[str], filter_expr: nw.Expr, expected: dict[str, list[Any]], @@ -202,7 +208,7 @@ def test_semi_join( @pytest.mark.parametrize("how", ["right", "full"]) -def test_join_not_implemented(constructor: Any, how: str) -> None: +def test_join_not_implemented(constructor: Constructor, how: str) -> None: data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) @@ -216,16 +222,20 @@ def test_join_not_implemented(constructor: Any, how: str) -> None: @pytest.mark.filterwarnings("ignore:the default coalesce behavior") -def test_left_join(constructor: Any) -> None: +def test_left_join(constructor: Constructor) -> None: data_left = { "antananarivo": [1.0, 2, 3], "bob": [4.0, 5, 6], "index": [0.0, 1.0, 2.0], } - data_right = {"antananarivo": [1.0, 2, 3], "c": [4.0, 5, 7], "index": [0.0, 1.0, 2.0]} + data_right = { + "antananarivo": [1.0, 2, 3], + "co": [4.0, 5, 7], + "index": [0.0, 1.0, 2.0], + } df_left = nw.from_native(constructor(data_left)) df_right = nw.from_native(constructor(data_right)) - result = df_left.join(df_right, left_on="bob", right_on="c", how="left").select( # type: ignore[arg-type] + result = df_left.join(df_right, left_on="bob", right_on="co", how="left").select( # type: ignore[arg-type] nw.all().fill_null(float("nan")) ) result = result.sort("index") @@ -236,11 +246,24 @@ def test_left_join(constructor: Any) -> None: "antananarivo_right": [1, 2, float("nan")], "index": [0, 1, 2], } + result_on_list = df_left.join( + df_right, # type: ignore[arg-type] + on=["antananarivo", "index"], + how="left", + ).select(nw.all().fill_null(float("nan"))) + result_on_list = result_on_list.sort("index") + expected_on_list = { + "antananarivo": [1, 2, 3], + "bob": [4, 5, 6], + "index": [0, 1, 2], + "co": 
[4, 5, 7], + } compare_dicts(result, expected) + compare_dicts(result_on_list, expected_on_list) @pytest.mark.filterwarnings("ignore: the default coalesce behavior") -def test_left_join_multiple_column(constructor: Any) -> None: +def test_left_join_multiple_column(constructor: Constructor) -> None: data_left = {"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "index": [0, 1, 2]} data_right = {"antananarivo": [1, 2, 3], "c": [4, 5, 6], "index": [0, 1, 2]} df_left = nw.from_native(constructor(data_left)) @@ -258,7 +281,7 @@ def test_left_join_multiple_column(constructor: Any) -> None: @pytest.mark.filterwarnings("ignore: the default coalesce behavior") -def test_left_join_overlapping_column(constructor: Any) -> None: +def test_left_join_overlapping_column(constructor: Constructor) -> None: data_left = { "antananarivo": [1.0, 2, 3], "bob": [4.0, 5, 6], @@ -304,7 +327,7 @@ def test_left_join_overlapping_column(constructor: Any) -> None: @pytest.mark.parametrize("how", ["inner", "left", "semi", "anti"]) -def test_join_keys_exceptions(constructor: Any, how: str) -> None: +def test_join_keys_exceptions(constructor: Constructor, how: str) -> None: data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) @@ -330,8 +353,10 @@ def test_join_keys_exceptions(constructor: Any, how: str) -> None: df.join(df, how=how, on="antananarivo", right_on="antananarivo") # type: ignore[arg-type] -def test_joinasof_numeric(constructor: Any, request: Any) -> None: - if "pyarrow_table" in str(constructor): +def test_joinasof_numeric( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "pyarrow_table" in str(constructor) or "cudf" in str(constructor): request.applymarker(pytest.mark.xfail) if parse_version(pd.__version__) < (2, 1) and ( ("pandas_pyarrow" in str(constructor)) or ("pandas_nullable" in str(constructor)) @@ -386,8 +411,8 @@ def test_joinasof_numeric(constructor: Any, request: Any) -> None: 
compare_dicts(result_nearest_on, expected_nearest) -def test_joinasof_time(constructor: Any, request: Any) -> None: - if "pyarrow_table" in str(constructor): +def test_joinasof_time(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "pyarrow_table" in str(constructor) or "cudf" in str(constructor): request.applymarker(pytest.mark.xfail) if parse_version(pd.__version__) < (2, 1) and ("pandas_pyarrow" in str(constructor)): request.applymarker(pytest.mark.xfail) @@ -464,8 +489,8 @@ def test_joinasof_time(constructor: Any, request: Any) -> None: compare_dicts(result_nearest_on, expected_nearest) -def test_joinasof_by(constructor: Any, request: Any) -> None: - if "pyarrow_table" in str(constructor): +def test_joinasof_by(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "pyarrow_table" in str(constructor) or "cudf" in str(constructor): request.applymarker(pytest.mark.xfail) if parse_version(pd.__version__) < (2, 1) and ( ("pandas_pyarrow" in str(constructor)) or ("pandas_nullable" in str(constructor)) @@ -499,7 +524,7 @@ def test_joinasof_by(constructor: Any, request: Any) -> None: @pytest.mark.parametrize("strategy", ["back", "furthest"]) def test_joinasof_not_implemented( - constructor: Any, strategy: Literal["backward", "forward"] + constructor: Constructor, strategy: Literal["backward", "forward"] ) -> None: data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) @@ -516,7 +541,7 @@ def test_joinasof_not_implemented( ) -def test_joinasof_keys_exceptions(constructor: Any) -> None: +def test_joinasof_keys_exceptions(constructor: Constructor) -> None: data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) @@ -557,7 +582,7 @@ def test_joinasof_keys_exceptions(constructor: Any) -> None: df.join_asof(df, right_on="antananarivo", on="antananarivo") # type: ignore[arg-type] -def 
test_joinasof_by_exceptions(constructor: Any) -> None: +def test_joinasof_by_exceptions(constructor: Constructor) -> None: data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) with pytest.raises( diff --git a/tests/frame/lit_test.py b/tests/frame/lit_test.py index e5756e035..aa18edb40 100644 --- a/tests/frame/lit_test.py +++ b/tests/frame/lit_test.py @@ -7,6 +7,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts if TYPE_CHECKING: @@ -17,7 +18,9 @@ ("dtype", "expected_lit"), [(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])], ) -def test_lit(constructor: Any, dtype: DType | None, expected_lit: list[Any]) -> None: +def test_lit( + constructor: Constructor, dtype: DType | None, expected_lit: list[Any] +) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_raw = constructor(data) df = nw.from_native(df_raw).lazy() @@ -31,7 +34,7 @@ def test_lit(constructor: Any, dtype: DType | None, expected_lit: list[Any]) -> compare_dicts(result, expected) -def test_lit_error(constructor: Any) -> None: +def test_lit_error(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_raw = constructor(data) df = nw.from_native(df_raw).lazy() diff --git a/tests/frame/pipe_test.py b/tests/frame/pipe_test.py index a9a50133f..b7b57e0a1 100644 --- a/tests/frame/pipe_test.py +++ b/tests/frame/pipe_test.py @@ -1,6 +1,5 @@ -from typing import Any - import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -9,7 +8,7 @@ } -def test_pipe(constructor: Any) -> None: +def test_pipe(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) columns = df.collect_schema().names() result = df.pipe(lambda _df: _df.select([x for x in columns if len(x) == 2])) diff --git a/tests/frame/rename_test.py 
b/tests/frame/rename_test.py index c58eccd4c..79cf3f243 100644 --- a/tests/frame/rename_test.py +++ b/tests/frame/rename_test.py @@ -1,10 +1,9 @@ -from typing import Any - import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_rename(constructor: Any) -> None: +def test_rename(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.rename({"a": "x", "b": "y"}) diff --git a/tests/frame/row_test.py b/tests/frame/row_test.py index 602c50f55..599dcaeaf 100644 --- a/tests/frame/row_test.py +++ b/tests/frame/row_test.py @@ -1,9 +1,14 @@ from typing import Any +import pytest + import narwhals.stable.v1 as nw -def test_row_column(constructor_eager: Any) -> None: +def test_row_column(request: Any, constructor_eager: Any) -> None: + if "cudf" in str(constructor_eager): + request.applymarker(pytest.mark.xfail) + data = { "a": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], "b": [11, 12, 13, 14, 15, 16], diff --git a/tests/frame/rows_test.py b/tests/frame/rows_test.py index e3123f19f..2d94ab18e 100644 --- a/tests/frame/rows_test.py +++ b/tests/frame/rows_test.py @@ -55,10 +55,14 @@ ], ) def test_iter_rows( + request: Any, constructor_eager: Any, named: bool, # noqa: FBT001 expected: list[tuple[Any, ...]] | list[dict[str, Any]], ) -> None: + if "cudf" in str(constructor_eager): + request.applymarker(pytest.mark.xfail) + data = {"a": [1, 3, 2], "_b": [4, 4, 6], "z": [7.0, 8, 9], "1": [5, 6, 7]} df = nw.from_native(constructor_eager(data), eager_only=True) result = list(df.iter_rows(named=named)) diff --git a/tests/frame/sample_test.py b/tests/frame/sample_test.py new file mode 100644 index 000000000..88d5969c3 --- /dev/null +++ b/tests/frame/sample_test.py @@ -0,0 +1,34 @@ +import narwhals.stable.v1 as nw +from tests.utils import Constructor + + +def test_sample_n(constructor_eager: Constructor) -> None: + df = nw.from_native( + 
constructor_eager({"a": [1, 2, 3, 4], "b": ["x", "y", "x", "y"]}), eager_only=True + ) + + result_expr = df.sample(n=2).shape + expected_expr = (2, 2) + assert result_expr == expected_expr + + +def test_sample_fraction(constructor_eager: Constructor) -> None: + df = nw.from_native( + constructor_eager({"a": [1, 2, 3, 4], "b": ["x", "y", "x", "y"]}), eager_only=True + ) + + result_expr = df.sample(fraction=0.5).shape + expected_expr = (2, 2) + assert result_expr == expected_expr + + +def test_sample_with_seed(constructor_eager: Constructor) -> None: + size, n = 100, 10 + df = nw.from_native(constructor_eager({"a": range(size)}), eager_only=True) + + r1 = nw.to_native(df.sample(n=n, seed=123)) + r2 = nw.to_native(df.sample(n=n, seed=123)) + r3 = nw.to_native(df.sample(n=n, seed=42)) + + assert r1.equals(r2) # type: ignore[union-attr] + assert not r1.equals(r3) # type: ignore[union-attr] diff --git a/tests/frame/schema_test.py b/tests/frame/schema_test.py index 6e6b33aa1..d7ba69ab2 100644 --- a/tests/frame/schema_test.py +++ b/tests/frame/schema_test.py @@ -10,6 +10,7 @@ import narwhals.stable.v1 as nw from narwhals.utils import parse_version +from tests.utils import Constructor data = { "a": [datetime(2020, 1, 1)], @@ -18,7 +19,7 @@ @pytest.mark.filterwarnings("ignore:Determining|Resolving.*") -def test_schema(constructor: Any) -> None: +def test_schema(constructor: Constructor) -> None: df = nw.from_native(constructor({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8, 9]})) result = df.schema expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} @@ -29,7 +30,7 @@ def test_schema(constructor: Any) -> None: assert result == expected -def test_collect_schema(constructor: Any) -> None: +def test_collect_schema(constructor: Constructor) -> None: df = nw.from_native(constructor({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8, 9]})) expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} @@ -58,7 +59,7 @@ def test_string_disguised_as_object() -> None: assert result["a"] 
== nw.String -def test_actual_object(request: Any, constructor_eager: Any) -> None: +def test_actual_object(request: pytest.FixtureRequest, constructor_eager: Any) -> None: if any(x in str(constructor_eager) for x in ("modin", "pyarrow_table")): request.applymarker(pytest.mark.xfail) diff --git a/tests/frame/select_test.py b/tests/frame/select_test.py index 450e91066..8c01be407 100644 --- a/tests/frame/select_test.py +++ b/tests/frame/select_test.py @@ -1,13 +1,12 @@ -from typing import Any - import pandas as pd import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_select(constructor: Any) -> None: +def test_select(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.select("a") @@ -15,7 +14,7 @@ def test_select(constructor: Any) -> None: compare_dicts(result, expected) -def test_empty_select(constructor: Any) -> None: +def test_empty_select(constructor: Constructor) -> None: result = nw.from_native(constructor({"a": [1, 2, 3]})).lazy().select() assert result.collect().shape == (0, 0) diff --git a/tests/frame/slice_test.py b/tests/frame/slice_test.py index 834e88bff..0867844f9 100644 --- a/tests/frame/slice_test.py +++ b/tests/frame/slice_test.py @@ -29,7 +29,9 @@ def test_slice_rows(constructor_eager: Any) -> None: compare_dicts(result, {"a": [3.0, 4.0], "b": [13, 14]}) -def test_slice_rows_with_step(request: Any, constructor_eager: Any) -> None: +def test_slice_rows_with_step( + request: pytest.FixtureRequest, constructor_eager: Any +) -> None: if "pyarrow_table" in str(constructor_eager): request.applymarker(pytest.mark.xfail) result = nw.from_native(constructor_eager(data))[1::2] @@ -147,6 +149,9 @@ def test_slice_slice_columns(constructor_eager: Any) -> None: result = df[:, [0, 2]] expected = {"a": [1, 2, 3], "c": [7, 8, 9]} compare_dicts(result, expected) + result = df[:2, [0, 2]] + expected = 
{"a": [1, 2], "c": [7, 8]} + compare_dicts(result, expected) result = df[["b", "c"]] expected = {"b": [4, 5, 6], "c": [7, 8, 9]} compare_dicts(result, expected) diff --git a/tests/frame/sort_test.py b/tests/frame/sort_test.py index 9e583f8ba..06f5d079f 100644 --- a/tests/frame/sort_test.py +++ b/tests/frame/sort_test.py @@ -1,10 +1,9 @@ -from typing import Any - import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_sort(constructor: Any) -> None: +def test_sort(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.sort("a", "b") diff --git a/tests/frame/tail_test.py b/tests/frame/tail_test.py index b64d9fa6c..f7e06475c 100644 --- a/tests/frame/tail_test.py +++ b/tests/frame/tail_test.py @@ -1,15 +1,15 @@ from __future__ import annotations from contextlib import nullcontext as does_not_raise -from typing import Any import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_tail(constructor: Any) -> None: +def test_tail(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9]} diff --git a/tests/frame/to_arrow_test.py b/tests/frame/to_arrow_test.py index c1f395e59..f20bdf28c 100644 --- a/tests/frame/to_arrow_test.py +++ b/tests/frame/to_arrow_test.py @@ -10,7 +10,7 @@ from narwhals.utils import parse_version -def test_to_arrow(request: Any, constructor_eager: Any) -> None: +def test_to_arrow(request: pytest.FixtureRequest, constructor_eager: Any) -> None: if "pandas" in str(constructor_eager) and parse_version(pd.__version__) < (1, 0, 0): # pyarrow requires pandas>=1.0.0 request.applymarker(pytest.mark.xfail) diff --git a/tests/frame/to_pandas_test.py b/tests/frame/to_pandas_test.py index 81685a606..671a5d857 100644 --- a/tests/frame/to_pandas_test.py +++ 
b/tests/frame/to_pandas_test.py @@ -14,7 +14,7 @@ parse_version(pd.__version__) < parse_version("2.0.0"), reason="too old for pandas-pyarrow", ) -def test_convert_pandas(constructor_eager: Any, request: Any) -> None: +def test_convert_pandas(constructor_eager: Any, request: pytest.FixtureRequest) -> None: if "modin" in str(constructor_eager): request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} diff --git a/tests/frame/unique_test.py b/tests/frame/unique_test.py index af61fe82b..40589c545 100644 --- a/tests/frame/unique_test.py +++ b/tests/frame/unique_test.py @@ -1,10 +1,9 @@ from __future__ import annotations -from typing import Any - import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} @@ -21,7 +20,7 @@ ], ) def test_unique( - constructor: Any, + constructor: Constructor, subset: str | list[str] | None, keep: str, expected: dict[str, list[float]], @@ -33,7 +32,7 @@ def test_unique( compare_dicts(result, expected) -def test_unique_none(constructor: Any) -> None: +def test_unique_none(constructor: Constructor) -> None: df_raw = constructor(data) df = nw.from_native(df_raw) diff --git a/tests/frame/with_columns_sequence_test.py b/tests/frame/with_columns_sequence_test.py index 123425122..49db7820b 100644 --- a/tests/frame/with_columns_sequence_test.py +++ b/tests/frame/with_columns_sequence_test.py @@ -1,9 +1,8 @@ -from typing import Any - import numpy as np import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -12,7 +11,7 @@ } -def test_with_columns(constructor: Any, request: Any) -> None: +def test_with_columns(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) result = ( diff --git a/tests/frame/with_columns_test.py 
b/tests/frame/with_columns_test.py index 864e689e8..44bcd39a5 100644 --- a/tests/frame/with_columns_test.py +++ b/tests/frame/with_columns_test.py @@ -1,9 +1,8 @@ -from typing import Any - import numpy as np import pandas as pd import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts @@ -18,7 +17,7 @@ def test_with_columns_int_col_name_pandas() -> None: pd.testing.assert_frame_equal(result, expected) -def test_with_columns_order(constructor: Any) -> None: +def test_with_columns_order(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.with_columns(nw.col("a") + 1, d=nw.col("a") - 1) @@ -27,14 +26,14 @@ def test_with_columns_order(constructor: Any) -> None: compare_dicts(result, expected) -def test_with_columns_empty(constructor: Any) -> None: +def test_with_columns_empty(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.select().with_columns() compare_dicts(result, {}) -def test_with_columns_order_single_row(constructor: Any) -> None: +def test_with_columns_order_single_row(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "i": [0, 1, 2]} df = nw.from_native(constructor(data)).filter(nw.col("i") < 1).drop("i") result = df.with_columns(nw.col("a") + 1, d=nw.col("a") - 1) diff --git a/tests/frame/with_row_index_test.py b/tests/frame/with_row_index_test.py index bc1c2fe0a..8f802de0a 100644 --- a/tests/frame/with_row_index_test.py +++ b/tests/frame/with_row_index_test.py @@ -1,6 +1,5 @@ -from typing import Any - import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -9,7 +8,7 @@ } -def test_with_row_index(constructor: Any) -> None: +def test_with_row_index(constructor: Constructor) -> None: result = 
nw.from_native(constructor(data)).with_row_index() expected = {"a": ["foo", "bars"], "ab": ["foo", "bars"], "index": [0, 1]} compare_dicts(result, expected) diff --git a/tests/from_dict_test.py b/tests/from_dict_test.py index cfaf99a7b..a1332908a 100644 --- a/tests/from_dict_test.py +++ b/tests/from_dict_test.py @@ -1,12 +1,11 @@ -from typing import Any - import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts -def test_from_dict(constructor: Any, request: Any) -> None: +def test_from_dict(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})) @@ -17,7 +16,9 @@ def test_from_dict(constructor: Any, request: Any) -> None: assert isinstance(result, nw.DataFrame) -def test_from_dict_schema(constructor: Any, request: Any) -> None: +def test_from_dict_schema( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) schema = {"c": nw.Int16(), "d": nw.Float32()} @@ -31,19 +32,23 @@ def test_from_dict_schema(constructor: Any, request: Any) -> None: assert result.collect_schema() == schema -def test_from_dict_without_namespace(constructor: Any) -> None: +def test_from_dict_without_namespace(constructor: Constructor) -> None: df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})).lazy().collect() result = nw.from_dict({"c": df["a"], "d": df["b"]}) compare_dicts(result, {"c": [1, 2, 3], "d": [4, 5, 6]}) -def test_from_dict_without_namespace_invalid(constructor: Any) -> None: +def test_from_dict_without_namespace_invalid( + constructor: Constructor, +) -> None: df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})).lazy().collect() with pytest.raises(TypeError, match="namespace"): nw.from_dict({"c": nw.to_native(df["a"]), "d": nw.to_native(df["b"])}) -def 
test_from_dict_one_native_one_narwhals(constructor: Any) -> None: +def test_from_dict_one_native_one_narwhals( + constructor: Constructor, +) -> None: df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})).lazy().collect() result = nw.from_dict({"c": nw.to_native(df["a"]), "d": df["b"]}) expected = {"c": [1, 2, 3], "d": [4, 5, 6]} diff --git a/tests/selectors_test.py b/tests/selectors_test.py index ababee4a7..c78a9eac4 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any - import pandas as pd import pyarrow as pa import pytest @@ -14,6 +12,7 @@ from narwhals.selectors import numeric from narwhals.selectors import string from narwhals.utils import parse_version +from tests.utils import Constructor from tests.utils import compare_dicts data = { @@ -24,28 +23,28 @@ } -def test_selectors(constructor: Any) -> None: +def test_selectors(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(by_dtype([nw.Int64, nw.Float64]) + 1) expected = {"a": [2, 2, 3], "c": [5.1, 6.0, 7.0]} compare_dicts(result, expected) -def test_numeric(constructor: Any) -> None: +def test_numeric(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(numeric() + 1) expected = {"a": [2, 2, 3], "c": [5.1, 6.0, 7.0]} compare_dicts(result, expected) -def test_boolean(constructor: Any) -> None: +def test_boolean(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(boolean()) expected = {"d": [True, False, True]} compare_dicts(result, expected) -def test_string(constructor: Any, request: Any) -> None: +def test_string(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor) and parse_version(pa.__version__) < (12,): # Dask doesn't infer `'b'` as String for old PyArrow versions request.applymarker(pytest.mark.xfail) @@ -55,7 +54,7 @@ def 
test_string(constructor: Any, request: Any) -> None: compare_dicts(result, expected) -def test_categorical(request: Any, constructor: Any) -> None: +def test_categorical(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "pyarrow_table_constructor" in str(constructor) and parse_version( pa.__version__ ) <= (15,): # pragma: no cover @@ -82,7 +81,7 @@ def test_categorical(request: Any, constructor: Any) -> None: ], ) def test_set_ops( - constructor: Any, selector: nw.selectors.Selector, expected: list[str] + constructor: Constructor, selector: nw.selectors.Selector, expected: list[str] ) -> None: df = nw.from_native(constructor(data)) result = df.select(selector).collect_schema().names() @@ -90,7 +89,7 @@ def test_set_ops( @pytest.mark.parametrize("invalid_constructor", [pd.DataFrame, pa.table]) -def test_set_ops_invalid(invalid_constructor: Any) -> None: +def test_set_ops_invalid(invalid_constructor: Constructor) -> None: df = nw.from_native(invalid_constructor(data)) with pytest.raises(NotImplementedError): df.select(1 - numeric()) diff --git a/tests/series_only/array_dunder_test.py b/tests/series_only/array_dunder_test.py index 0449199ef..c09bea9ec 100644 --- a/tests/series_only/array_dunder_test.py +++ b/tests/series_only/array_dunder_test.py @@ -10,7 +10,7 @@ from tests.utils import compare_dicts -def test_array_dunder(request: Any, constructor_eager: Any) -> None: +def test_array_dunder(request: pytest.FixtureRequest, constructor_eager: Any) -> None: if "pyarrow_table" in str(constructor_eager) and parse_version( pa.__version__ ) < parse_version("16.0.0"): # pragma: no cover @@ -21,7 +21,9 @@ def test_array_dunder(request: Any, constructor_eager: Any) -> None: np.testing.assert_array_equal(result, np.array([1, 2, 3], dtype="int64")) -def test_array_dunder_with_dtype(request: Any, constructor_eager: Any) -> None: +def test_array_dunder_with_dtype( + request: pytest.FixtureRequest, constructor_eager: Any +) -> None: if "pyarrow_table" in 
str(constructor_eager) and parse_version( pa.__version__ ) < parse_version("16.0.0"): # pragma: no cover @@ -32,7 +34,9 @@ def test_array_dunder_with_dtype(request: Any, constructor_eager: Any) -> None: np.testing.assert_array_equal(result, np.array([1, 2, 3], dtype=object)) -def test_array_dunder_with_copy(request: Any, constructor_eager: Any) -> None: +def test_array_dunder_with_copy( + request: pytest.FixtureRequest, constructor_eager: Any +) -> None: if "pyarrow_table" in str(constructor_eager) and parse_version( pa.__version__ ) < parse_version("16.0.0"): # pragma: no cover diff --git a/tests/series_only/slice_test.py b/tests/series_only/slice_test.py index f9d2b4e2f..48cf15bc7 100644 --- a/tests/series_only/slice_test.py +++ b/tests/series_only/slice_test.py @@ -13,3 +13,15 @@ def test_slice(constructor_eager: Any) -> None: result = {"a": df["a"][1:]} expected = {"a": [2, 3]} compare_dicts(result, expected) + result = {"b": df[:, 1]} + expected = {"b": [4, 5, 6]} + compare_dicts(result, expected) + result = {"b": df[:, "b"]} + expected = {"b": [4, 5, 6]} + compare_dicts(result, expected) + result = {"b": df[:2, "b"]} + expected = {"b": [4, 5]} + compare_dicts(result, expected) + result = {"b": df[:2, 1]} + expected = {"b": [4, 5]} + compare_dicts(result, expected) diff --git a/tests/series_only/to_arrow_test.py b/tests/series_only/to_arrow_test.py index ebd90b7c2..5181a6786 100644 --- a/tests/series_only/to_arrow_test.py +++ b/tests/series_only/to_arrow_test.py @@ -19,7 +19,9 @@ def test_to_arrow(constructor_eager: Any) -> None: assert pc.all(pc.equal(result, pa.array(data, type=pa.int64()))) -def test_to_arrow_with_nulls(constructor_eager: Any, request: Any) -> None: +def test_to_arrow_with_nulls( + constructor_eager: Any, request: pytest.FixtureRequest +) -> None: if "pandas_constructor" in str(constructor_eager) or "modin_constructor" in str( constructor_eager ): diff --git a/tests/series_only/to_dummy_test.py b/tests/series_only/to_dummy_test.py index 
404ac6321..c3d57b9ad 100644 --- a/tests/series_only/to_dummy_test.py +++ b/tests/series_only/to_dummy_test.py @@ -18,7 +18,9 @@ def test_to_dummies(constructor_eager: Any, sep: str) -> None: @pytest.mark.parametrize("sep", ["_", "-"]) -def test_to_dummies_drop_first(request: Any, constructor_eager: Any, sep: str) -> None: +def test_to_dummies_drop_first( + request: pytest.FixtureRequest, constructor_eager: Any, sep: str +) -> None: if "cudf" in str(constructor_eager): request.applymarker(pytest.mark.xfail) s = nw.from_native(constructor_eager({"a": data}), eager_only=True)["a"].alias("a") diff --git a/tests/series_only/to_list_test.py b/tests/series_only/to_list_test.py index 10e415916..11d02d0d2 100644 --- a/tests/series_only/to_list_test.py +++ b/tests/series_only/to_list_test.py @@ -8,7 +8,7 @@ data = [1, 2, 3] -def test_to_list(constructor_eager: Any, request: Any) -> None: +def test_to_list(constructor_eager: Any, request: pytest.FixtureRequest) -> None: if "cudf" in str(constructor_eager): # pragma: no cover request.applymarker(pytest.mark.xfail) s = nw.from_native(constructor_eager({"a": data}), eager_only=True)["a"] diff --git a/tests/series_only/to_numpy_test.py b/tests/series_only/to_numpy_test.py index f5ed59fe1..433ede16a 100644 --- a/tests/series_only/to_numpy_test.py +++ b/tests/series_only/to_numpy_test.py @@ -9,7 +9,7 @@ import narwhals.stable.v1 as nw -def test_to_numpy(constructor_eager: Any, request: Any) -> None: +def test_to_numpy(constructor_eager: Any, request: pytest.FixtureRequest) -> None: if "pandas_constructor" in str(constructor_eager) or "modin_constructor" in str( constructor_eager ): diff --git a/tests/series_only/to_pandas_test.py b/tests/series_only/to_pandas_test.py index 747bed8b2..30c7906c7 100644 --- a/tests/series_only/to_pandas_test.py +++ b/tests/series_only/to_pandas_test.py @@ -15,7 +15,7 @@ @pytest.mark.skipif( parse_version(pd.__version__) < parse_version("2.0.0"), reason="too old for pyarrow" ) -def 
test_convert(request: Any, constructor_eager: Any) -> None: +def test_convert(request: pytest.FixtureRequest, constructor_eager: Any) -> None: if any( cname in str(constructor_eager) for cname in ("pandas_nullable", "pandas_pyarrow", "modin") diff --git a/tests/series_only/value_counts_test.py b/tests/series_only/value_counts_test.py index d3d48066f..d19a1440b 100644 --- a/tests/series_only/value_counts_test.py +++ b/tests/series_only/value_counts_test.py @@ -15,7 +15,10 @@ @pytest.mark.parametrize("normalize", [True, False]) @pytest.mark.parametrize("name", [None, "count_name"]) def test_value_counts( - request: Any, constructor_eager: Any, normalize: Any, name: str | None + request: pytest.FixtureRequest, + constructor_eager: Any, + normalize: Any, + name: str | None, ) -> None: if "pandas_nullable_constructor" in str(constructor_eager) and parse_version( pd.__version__ diff --git a/tests/stable_api_test.py b/tests/stable_api_test.py index 211ba8652..375af4f4d 100644 --- a/tests/stable_api_test.py +++ b/tests/stable_api_test.py @@ -1,14 +1,13 @@ -from typing import Any - import polars as pl import pytest import narwhals as nw import narwhals.stable.v1 as nw_v1 +from tests.utils import Constructor from tests.utils import compare_dicts -def test_renamed_taxicab_norm(constructor: Any) -> None: +def test_renamed_taxicab_norm(constructor: Constructor) -> None: # Suppose we need to rename `_l1_norm` to `_taxicab_norm`. # We need `narwhals.stable.v1` to stay stable. 
So, we # make the change in `narwhals`, and then add the new method diff --git a/tests/test_group_by.py b/tests/test_group_by.py index 6f12d06b1..fa9c05f4b 100644 --- a/tests/test_group_by.py +++ b/tests/test_group_by.py @@ -10,6 +10,7 @@ import narwhals.stable.v1 as nw from narwhals.utils import parse_version +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": [1, 1, 3], "b": [4, 4, 6], "c": [7.0, 8, 9]} @@ -94,7 +95,7 @@ def test_group_by_iter(constructor_eager: Any) -> None: assert sorted(keys) == sorted(expected_keys) -def test_group_by_len(constructor: Any) -> None: +def test_group_by_len(constructor: Constructor) -> None: result = ( nw.from_native(constructor(data)).group_by("a").agg(nw.col("b").len()).sort("a") ) @@ -102,7 +103,7 @@ def test_group_by_len(constructor: Any) -> None: compare_dicts(result, expected) -def test_group_by_n_unique(constructor: Any) -> None: +def test_group_by_n_unique(constructor: Constructor) -> None: result = ( nw.from_native(constructor(data)) .group_by("a") @@ -113,7 +114,7 @@ def test_group_by_n_unique(constructor: Any) -> None: compare_dicts(result, expected) -def test_group_by_std(constructor: Any) -> None: +def test_group_by_std(constructor: Constructor) -> None: data = {"a": [1, 1, 2, 2], "b": [5, 4, 3, 2]} result = ( nw.from_native(constructor(data)).group_by("a").agg(nw.col("b").std()).sort("a") @@ -122,7 +123,7 @@ def test_group_by_std(constructor: Any) -> None: compare_dicts(result, expected) -def test_group_by_n_unique_w_missing(constructor: Any) -> None: +def test_group_by_n_unique_w_missing(constructor: Constructor) -> None: data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]} result = ( nw.from_native(constructor(data)) @@ -162,7 +163,7 @@ def test_group_by_empty_result_pandas() -> None: ) -def test_group_by_simple_named(constructor: Any) -> None: +def test_group_by_simple_named(constructor: Constructor) -> None: data = {"a": [1, 1, 2], "b": [4, 5, 6], 
"c": [7, 2, 1]} df = nw.from_native(constructor(data)).lazy() result = ( @@ -182,7 +183,7 @@ def test_group_by_simple_named(constructor: Any) -> None: compare_dicts(result, expected) -def test_group_by_simple_unnamed(constructor: Any) -> None: +def test_group_by_simple_unnamed(constructor: Constructor) -> None: data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} df = nw.from_native(constructor(data)).lazy() result = ( @@ -202,7 +203,7 @@ def test_group_by_simple_unnamed(constructor: Any) -> None: compare_dicts(result, expected) -def test_group_by_multiple_keys(constructor: Any) -> None: +def test_group_by_multiple_keys(constructor: Constructor) -> None: data = {"a": [1, 1, 2], "b": [4, 4, 6], "c": [7, 2, 1]} df = nw.from_native(constructor(data)).lazy() result = ( @@ -223,7 +224,7 @@ def test_group_by_multiple_keys(constructor: Any) -> None: compare_dicts(result, expected) -def test_key_with_nulls(constructor: Any, request: Any) -> None: +def test_key_with_nulls(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "modin" in str(constructor): # TODO(unassigned): Modin flaky here? 
request.applymarker(pytest.mark.skip) @@ -248,7 +249,7 @@ def test_key_with_nulls(constructor: Any, request: Any) -> None: compare_dicts(result, expected) -def test_no_agg(constructor: Any) -> None: +def test_no_agg(constructor: Constructor) -> None: result = nw.from_native(constructor(data)).group_by(["a", "b"]).agg().sort("a", "b") expected = {"a": [1, 3], "b": [4, 6]} diff --git a/tests/tpch_q1_test.py b/tests/tpch_q1_test.py index 999f32f76..c506ee0de 100644 --- a/tests/tpch_q1_test.py +++ b/tests/tpch_q1_test.py @@ -2,7 +2,6 @@ import os from datetime import datetime -from typing import Any from unittest import mock import pandas as pd @@ -20,7 +19,7 @@ ["pandas", "polars", "pyarrow", "dask"], ) @pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning") -def test_q1(library: str, request: Any) -> None: +def test_q1(library: str, request: pytest.FixtureRequest) -> None: if library == "pandas" and parse_version(pd.__version__) < (1, 5): request.applymarker(pytest.mark.xfail) elif library == "pandas": @@ -99,7 +98,7 @@ def test_q1(library: str, request: Any) -> None: "ignore:.*Passing a BlockManager.*:DeprecationWarning", "ignore:.*Complex.*:UserWarning", ) -def test_q1_w_generic_funcs(library: str, request: Any) -> None: +def test_q1_w_generic_funcs(library: str, request: pytest.FixtureRequest) -> None: if library == "pandas" and parse_version(pd.__version__) < (1, 5): request.applymarker(pytest.mark.xfail) elif library == "pandas": diff --git a/tests/utils.py b/tests/utils.py index 6ab703c3b..b13bec192 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,11 +4,21 @@ import sys import warnings from typing import Any +from typing import Callable from typing import Iterator from typing import Sequence import pandas as pd +from narwhals.typing import IntoFrame + +if sys.version_info >= (3, 10): + from typing import TypeAlias # pragma: no cover +else: + from typing_extensions import TypeAlias # pragma: no cover + +Constructor: TypeAlias = 
Callable[[Any], IntoFrame] + def zip_strict(left: Sequence[Any], right: Sequence[Any]) -> Iterator[Any]: if len(left) != len(right): diff --git a/tpch/execute/q10.py b/tpch/execute/q10.py index e1d56d36b..124bf0f7d 100644 --- a/tpch/execute/q10.py +++ b/tpch/execute/q10.py @@ -17,3 +17,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q10.query(fn(customer), fn(nation), fn(lineitem), fn(orders))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q10.query(fn(customer), fn(nation), fn(lineitem), fn(orders)).compute()) diff --git a/tpch/execute/q11.py b/tpch/execute/q11.py index a6b830f30..8c0a2e649 100644 --- a/tpch/execute/q11.py +++ b/tpch/execute/q11.py @@ -16,3 +16,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q11.query(fn(nation), fn(partsupp), fn(supplier))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q11.query(fn(nation), fn(partsupp), fn(supplier)).compute()) diff --git a/tpch/execute/q12.py b/tpch/execute/q12.py index 0cdc0378b..3c3a70c62 100644 --- a/tpch/execute/q12.py +++ b/tpch/execute/q12.py @@ -15,3 +15,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q12.query(fn(line_item), fn(orders))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q12.query(fn(line_item), fn(orders)).compute()) diff --git a/tpch/execute/q13.py b/tpch/execute/q13.py index b5e6c8bbe..2fdda5bd3 100644 --- a/tpch/execute/q13.py +++ b/tpch/execute/q13.py @@ -15,3 +15,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q13.query(fn(customer), fn(orders))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q13.query(fn(customer), fn(orders)).compute()) diff --git a/tpch/execute/q14.py b/tpch/execute/q14.py index 1a89dbbbe..dfd54056e 100644 --- a/tpch/execute/q14.py +++ b/tpch/execute/q14.py @@ -15,3 +15,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q14.query(fn(line_item), fn(part))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q14.query(fn(line_item), fn(part)).compute()) diff --git a/tpch/execute/q15.py b/tpch/execute/q15.py index ac858841d..86a03b0a0 100644 --- a/tpch/execute/q15.py +++ 
b/tpch/execute/q15.py @@ -15,3 +15,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q15.query(fn(lineitem), fn(supplier))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q15.query(fn(lineitem), fn(supplier)).compute()) diff --git a/tpch/execute/q16.py b/tpch/execute/q16.py index 7fa6c72b0..6a70279d0 100644 --- a/tpch/execute/q16.py +++ b/tpch/execute/q16.py @@ -16,3 +16,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q16.query(fn(part), fn(partsupp), fn(supplier))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q16.query(fn(part), fn(partsupp), fn(supplier)).compute()) diff --git a/tpch/execute/q17.py b/tpch/execute/q17.py index 8eefb92dc..43ef4f8b1 100644 --- a/tpch/execute/q17.py +++ b/tpch/execute/q17.py @@ -15,3 +15,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q17.query(fn(lineitem), fn(part))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q17.query(fn(lineitem), fn(part)).compute()) diff --git a/tpch/execute/q18.py b/tpch/execute/q18.py index fdd50c095..c7e5b7954 100644 --- a/tpch/execute/q18.py +++ b/tpch/execute/q18.py @@ -16,3 +16,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q18.query(fn(customer), fn(lineitem), fn(orders))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q18.query(fn(customer), fn(lineitem), fn(orders)).compute()) diff --git a/tpch/execute/q19.py b/tpch/execute/q19.py index e1dff3eb5..60f91b052 100644 --- a/tpch/execute/q19.py +++ b/tpch/execute/q19.py @@ -12,3 +12,6 @@ fn = IO_FUNCS["pyarrow"] print(q19.query(fn(lineitem), fn(part))) + +fn = IO_FUNCS["dask"] +print(q19.query(fn(lineitem), fn(part)).compute()) diff --git a/tpch/execute/q20.py b/tpch/execute/q20.py index d15f8c85f..3984b7580 100644 --- a/tpch/execute/q20.py +++ b/tpch/execute/q20.py @@ -15,3 +15,6 @@ fn = IO_FUNCS["pyarrow"] print(q20.query(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(supplier))) + +fn = IO_FUNCS["dask"] +print(q20.query(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(supplier)).compute()) diff --git a/tpch/execute/q21.py b/tpch/execute/q21.py 
index 9940e6232..7cf772d8e 100644 --- a/tpch/execute/q21.py +++ b/tpch/execute/q21.py @@ -14,3 +14,6 @@ fn = IO_FUNCS["pyarrow"] print(q21.query(fn(lineitem), fn(nation), fn(orders), fn(supplier))) + +fn = IO_FUNCS["dask"] +print(q21.query(fn(lineitem), fn(nation), fn(orders), fn(supplier)).compute()) diff --git a/tpch/execute/q22.py b/tpch/execute/q22.py index 3b3fe523f..a2bb1e76d 100644 --- a/tpch/execute/q22.py +++ b/tpch/execute/q22.py @@ -15,3 +15,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q22.query(fn(customer), fn(orders))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q22.query(fn(customer), fn(orders)).compute()) diff --git a/tpch/execute/q3.py b/tpch/execute/q3.py index f836fae27..d6b9302cc 100644 --- a/tpch/execute/q3.py +++ b/tpch/execute/q3.py @@ -16,3 +16,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q3.query(fn(customer), fn(lineitem), fn(orders))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q3.query(fn(customer), fn(lineitem), fn(orders)).compute()) diff --git a/tpch/execute/q4.py b/tpch/execute/q4.py index ca60f38ee..5645574f8 100644 --- a/tpch/execute/q4.py +++ b/tpch/execute/q4.py @@ -15,3 +15,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q4.query(fn(line_item), fn(orders))) + +tool = "dask" +fn = IO_FUNCS[tool] +print(q4.query(fn(line_item), fn(orders)).compute()) diff --git a/tpch/execute/q5.py b/tpch/execute/q5.py index c343fea5d..dcc61027b 100644 --- a/tpch/execute/q5.py +++ b/tpch/execute/q5.py @@ -31,3 +31,11 @@ fn(region), fn(nation), fn(customer), fn(line_item), fn(orders), fn(supplier) ) ) + +tool = "dask" +fn = IO_FUNCS[tool] +print( + q5.query( + fn(region), fn(nation), fn(customer), fn(line_item), fn(orders), fn(supplier) + ).compute() +) diff --git a/tpch/execute/q6.py b/tpch/execute/q6.py index eebf3f864..154964ff4 100644 --- a/tpch/execute/q6.py +++ b/tpch/execute/q6.py @@ -14,3 +14,7 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q6.query(fn(lineitem))) + +tool = "dask" +fn = IO_FUNCS[tool] 
+print(q6.query(fn(lineitem)).compute()) diff --git a/tpch/execute/q7.py b/tpch/execute/q7.py index c59f82ce7..a08d5641c 100644 --- a/tpch/execute/q7.py +++ b/tpch/execute/q7.py @@ -20,3 +20,9 @@ tool = "pyarrow" fn = IO_FUNCS[tool] print(q7.query(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))) + +tool = "dask" +fn = IO_FUNCS[tool] +print( + q7.query(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier)).compute() +) diff --git a/tpch/execute/q8.py b/tpch/execute/q8.py index 902a34e70..a76a8051f 100644 --- a/tpch/execute/q8.py +++ b/tpch/execute/q8.py @@ -51,3 +51,17 @@ fn(region), ) ) + +tool = "dask" +fn = IO_FUNCS[tool] +print( + q8.query( + fn(part), + fn(supplier), + fn(lineitem), + fn(orders), + fn(customer), + fn(nation), + fn(region), + ).compute() +) diff --git a/tpch/execute/q9.py b/tpch/execute/q9.py index 44d4154aa..14230af64 100644 --- a/tpch/execute/q9.py +++ b/tpch/execute/q9.py @@ -27,3 +27,11 @@ print( q9.query(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(orders), fn(supplier)) ) + +tool = "dask" +fn = IO_FUNCS[tool] +print( + q9.query( + fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(orders), fn(supplier) + ).compute() +) diff --git a/tpch/generate_data.py b/tpch/generate_data.py index 4d5695dcf..5fd73b1f7 100644 --- a/tpch/generate_data.py +++ b/tpch/generate_data.py @@ -10,7 +10,7 @@ con = duckdb.connect(database=":memory:") con.execute("INSTALL tpch; LOAD tpch") -con.execute("CALL dbgen(sf=1)") +con.execute("CALL dbgen(sf=.5)") tables = [ "lineitem", "customer", diff --git a/tpch/queries/q20.py b/tpch/queries/q20.py index d9014f7b8..b0dabb29e 100644 --- a/tpch/queries/q20.py +++ b/tpch/queries/q20.py @@ -28,7 +28,8 @@ def query( return ( part_ds.filter(nw.col("p_name").str.starts_with(var4)) - .select(nw.col("p_partkey").unique()) + .select("p_partkey") + .unique("p_partkey") .join(partsupp_ds, left_on="p_partkey", right_on="ps_partkey") .join( query1, @@ -36,7 +37,8 @@ def query( 
right_on=["l_suppkey", "l_partkey"], ) .filter(nw.col("ps_availqty") > nw.col("sum_quantity")) - .select(nw.col("ps_suppkey").unique()) + .select("ps_suppkey") + .unique("ps_suppkey") .join(query3, left_on="ps_suppkey", right_on="s_suppkey") .select("s_name", "s_address") .sort("s_name") diff --git a/tpch/queries/q22.py b/tpch/queries/q22.py index 4738c6fd3..2e0973227 100644 --- a/tpch/queries/q22.py +++ b/tpch/queries/q22.py @@ -14,8 +14,10 @@ def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT: nw.col("c_acctbal").mean().alias("avg_acctbal") ) - q3 = orders_ds.select(nw.col("o_custkey").unique()).with_columns( - nw.col("o_custkey").alias("c_custkey") + q3 = ( + orders_ds.select("o_custkey") + .unique("o_custkey") + .with_columns(nw.col("o_custkey").alias("c_custkey")) ) return ( diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index ec599def5..1bf1f086e 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -9,15 +9,30 @@ ret = 0 +NAMESPACES = {"dt", "str", "cat", "name"} +EXPR_ONLY_METHODS = {"over"} +SERIES_ONLY_METHODS = { + "to_arrow", + "to_dummies", + "to_pandas", + "to_list", + "to_numpy", + "dtype", + "name", + "shape", + "to_frame", + "is_empty", + "is_sorted", + "value_counts", + "zip_with", + "item", + "scatter", +} + # TODO(Unassigned): make dtypes reference page as well files = {remove_suffix(i, ".py") for i in os.listdir("narwhals")} top_level_functions = [ - i - for i in nw.__dir__() - if not i[0].isupper() - and i[0] != "_" - and i not in files - and i not in {"annotations", "DataFrame", "LazyFrame", "Series"} + i for i in nw.__dir__() if not i[0].isupper() and i[0] != "_" and i not in files ] with open("docs/api-reference/narwhals.md") as fd: content = fd.read() @@ -89,11 +104,7 @@ for i in content.splitlines() if i.startswith(" - ") and not i.startswith(" - _") ] -if ( - missing := set(top_level_functions) - .difference(documented) - .difference({"dt", "str", "cat", "name"}) -): +if 
missing := set(top_level_functions).difference(documented).difference(NAMESPACES): print("Series: not documented") # noqa: T201 print(missing) # noqa: T201 ret = 1 @@ -112,11 +123,7 @@ for i in content.splitlines() if i.startswith(" - ") ] -if ( - missing := set(top_level_functions) - .difference(documented) - .difference({"cat", "str", "dt", "name"}) -): +if missing := set(top_level_functions).difference(documented).difference(NAMESPACES): print("Expr: not documented") # noqa: T201 print(missing) # noqa: T201 ret = 1 @@ -139,33 +146,11 @@ if not i[0].isupper() and i[0] != "_" ] -if missing := set(expr).difference(series).difference({"over"}): +if missing := set(expr).difference(series).difference(EXPR_ONLY_METHODS): print("In expr but not in series") # noqa: T201 print(missing) # noqa: T201 ret = 1 -if ( - extra := set(series) - .difference(expr) - .difference( - { - "to_arrow", - "to_dummies", - "to_pandas", - "to_list", - "to_numpy", - "dtype", - "name", - "shape", - "to_frame", - "is_empty", - "is_sorted", - "value_counts", - "zip_with", - "item", - "scatter", - } - ) -): +if extra := set(series).difference(expr).difference(SERIES_ONLY_METHODS): print("in series but not in expr") # noqa: T201 print(extra) # noqa: T201 ret = 1