Skip to content

Commit

Permalink
feat: add seed to Expr and Series sample
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Sep 14, 2024
1 parent 586f8d7 commit c4637ba
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 20 deletions.
10 changes: 8 additions & 2 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,18 @@ def arg_true(self) -> Self:
def sample(
self: Self,
n: int | None = None,
fraction: float | None = None,
*,
fraction: float | None = None,
with_replacement: bool = False,
seed: int | None = None,
) -> Self:
return reuse_series_implementation(
self, "sample", n=n, fraction=fraction, with_replacement=with_replacement
self,
"sample",
n=n,
fraction=fraction,
with_replacement=with_replacement,
seed=seed,
)

def fill_null(self: Self, value: Any) -> Self:
Expand Down
5 changes: 4 additions & 1 deletion narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,10 @@ def zip_with(self: Self, mask: Self, other: Self) -> Self:
def sample(
self: Self,
n: int | None = None,
fraction: float | None = None,
*,
fraction: float | None = None,
with_replacement: bool = False,
seed: int | None = None,
) -> Self:
import numpy as np # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import()
Expand All @@ -547,8 +548,10 @@ def sample(
if n is None and fraction is not None:
n = int(num_rows * fraction)

np.random.seed(seed)
idx = np.arange(0, num_rows)
mask = np.random.choice(idx, size=n, replace=with_replacement)

return self._from_native_series(pc.take(ser, mask))

def fill_null(self: Self, value: Any) -> Self:
Expand Down
12 changes: 9 additions & 3 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,20 @@ def shift(self, n: int) -> Self:
return reuse_series_implementation(self, "shift", n=n)

def sample(
self,
self: Self,
n: int | None = None,
fraction: float | None = None,
*,
fraction: float | None = None,
with_replacement: bool = False,
seed: int | None = None,
) -> Self:
return reuse_series_implementation(
self, "sample", n=n, fraction=fraction, with_replacement=with_replacement
self,
"sample",
n=n,
fraction=fraction,
with_replacement=with_replacement,
seed=seed,
)

def alias(self, name: str) -> Self:
Expand Down
9 changes: 5 additions & 4 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,15 +436,16 @@ def n_unique(self) -> int:
return ser.nunique(dropna=False) # type: ignore[no-any-return]

def sample(
self,
self: Self,
n: int | None = None,
fraction: float | None = None,
*,
fraction: float | None = None,
with_replacement: bool = False,
) -> PandasLikeSeries:
seed: int | None = None,
) -> Self:
ser = self._native_series
return self._from_native_series(
ser.sample(n=n, frac=fraction, replace=with_replacement)
ser.sample(n=n, frac=fraction, replace=with_replacement, random_state=seed)
)

def abs(self) -> PandasLikeSeries:
Expand Down
11 changes: 6 additions & 5 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,21 +1228,22 @@ def drop_nulls(self) -> Self:
return self.__class__(lambda plx: self._call(plx).drop_nulls())

def sample(
self,
self: Self,
n: int | None = None,
fraction: float | None = None,
*,
fraction: float | None = None,
with_replacement: bool = False,
seed: int | None = None,
) -> Self:
"""
Sample randomly from this expression.
Arguments:
n: Number of items to return. Cannot be used with fraction.
fraction: Fraction of items to return. Cannot be used with n.
with_replacement: Allow values to be sampled more than once.
seed: Seed for the random number generator. If set to None (default), a random
seed is generated for each sample operation.
Examples:
>>> import narwhals as nw
Expand Down Expand Up @@ -1279,7 +1280,7 @@ def sample(
"""
return self.__class__(
lambda plx: self._call(plx).sample(
n, fraction=fraction, with_replacement=with_replacement
n, fraction=fraction, with_replacement=with_replacement, seed=seed
)
)

Expand Down
11 changes: 6 additions & 5 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,21 +1076,22 @@ def shift(self, n: int) -> Self:
return self._from_compliant_series(self._compliant_series.shift(n))

def sample(
self,
self: Self,
n: int | None = None,
fraction: float | None = None,
*,
fraction: float | None = None,
with_replacement: bool = False,
seed: int | None = None,
) -> Self:
"""
Sample randomly from this Series.
Arguments:
n: Number of items to return. Cannot be used with fraction.
fraction: Fraction of items to return. Cannot be used with n.
with_replacement: Allow values to be sampled more than once.
seed: Seed for the random number generator. If set to None (default), a random
seed is generated for each sample operation.
Notes:
The `sample` method returns a Series with a specified number of
Expand Down Expand Up @@ -1131,7 +1132,7 @@ def sample(
"""
return self._from_compliant_series(
self._compliant_series.sample(
n=n, fraction=fraction, with_replacement=with_replacement
n=n, fraction=fraction, with_replacement=with_replacement, seed=seed
)
)

Expand Down
35 changes: 35 additions & 0 deletions tests/expr_and_series/sample_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import compare_dicts


def test_expr_sample(constructor: Constructor, request: pytest.FixtureRequest) -> None:
Expand Down Expand Up @@ -32,3 +33,37 @@ def test_expr_sample_fraction(
result_series = df.collect()["a"].sample(fraction=0.1).shape
expected_series = (3,)
assert result_series == expected_series


def test_sample_with_seed(
    constructor: Constructor, request: pytest.FixtureRequest
) -> None:
    """Check that ``sample`` is reproducible for a fixed ``seed``.

    Two samples drawn with the same seed are expected to be equal, while a
    sample drawn with a different seed is expected to differ. The check is
    performed once through the expression API and once through the Series
    API. Constructors whose string form contains ``"dask"`` are marked as
    expected failures.
    """
    if "dask" in str(constructor):
        request.applymarker(pytest.mark.xfail)

    n_rows = 100
    n_sampled = 10
    frame = nw.from_native(constructor({"a": list(range(n_rows))})).lazy()
    # res1: same seed -> identical samples; res2: different seed -> not identical.
    expected = {"res1": [True], "res2": [False]}

    # Expression API: draw three samples, then compare them column-wise.
    sampled = frame.select(
        seed1=nw.col("a").sample(n=n_sampled, seed=123),
        seed2=nw.col("a").sample(n=n_sampled, seed=123),
        seed3=nw.col("a").sample(n=n_sampled, seed=42),
    )
    compared = sampled.select(
        res1=(nw.col("seed1") == nw.col("seed2")).all(),
        res2=(nw.col("seed1") == nw.col("seed3")).all(),
    )
    compare_dicts(compared.collect(), expected)

    # Series API: same three draws directly on the collected column.
    column = frame.collect()["a"]
    first = column.sample(n=n_sampled, seed=123)
    repeat = column.sample(n=n_sampled, seed=123)
    other = column.sample(n=n_sampled, seed=42)

    series_outcome = {
        "res1": [(first == repeat).all()],
        "res2": [(first == other).all()],
    }
    compare_dicts(series_outcome, expected)

0 comments on commit c4637ba

Please sign in to comment.