From 7fff1e8ef52d8a38f8297c0314115414b45643ea Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 8 Apr 2024 18:44:54 +0100 Subject: [PATCH 1/4] feat: support weekend argument in business_day_count --- crates/polars-ops/src/series/ops/business.rs | 14 ++-- .../src/dsl/function_expr/business.rs | 12 ++-- .../polars-plan/src/dsl/functions/business.rs | 4 +- py-polars/polars/functions/business.py | 65 +++++++++++++++++-- py-polars/polars/type_aliases.py | 9 +++ py-polars/src/functions/business.rs | 4 +- .../time_series/test_business_day_count.py | 16 ++++- .../business/test_business_day_count.py | 36 ++++++++++ 8 files changed, 138 insertions(+), 22 deletions(-) diff --git a/crates/polars-ops/src/series/ops/business.rs b/crates/polars-ops/src/series/ops/business.rs index 5b792453b5c4..3b975ec778ec 100644 --- a/crates/polars-ops/src/series/ops/business.rs +++ b/crates/polars-ops/src/series/ops/business.rs @@ -2,12 +2,18 @@ use polars_core::prelude::arity::binary_elementwise_values; use polars_core::prelude::*; /// Count the number of business days between `start` and `end`, excluding `end`. -pub fn business_day_count(start: &Series, end: &Series) -> PolarsResult { +/// +/// # Arguments +/// - `start`: Series holding start dates. +/// - `end`: Series holding end dates. +/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day. +pub fn business_day_count( + start: &Series, + end: &Series, + week_mask: [bool; 7], +) -> PolarsResult { let start_dates = start.date()?; let end_dates = end.date()?; - - // TODO: support customising weekdays - let week_mask: [bool; 7] = [true, true, true, true, true, false, false]; let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32; let out = match (start_dates.len(), end_dates.len()) { diff --git a/crates/polars-plan/src/dsl/function_expr/business.rs b/crates/polars-plan/src/dsl/function_expr/business.rs index f9a38b1165cc..745dcfdff8f5 100644 --- a/crates/polars-plan/src/dsl/function_expr/business.rs +++ b/crates/polars-plan/src/dsl/function_expr/business.rs @@ -12,7 +12,7 @@ use crate::prelude::SeriesUdf; #[derive(Clone, PartialEq, Debug, Eq, Hash)] pub enum BusinessFunction { #[cfg(feature = "business")] - BusinessDayCount, + BusinessDayCount { week_mask: [bool; 7] }, } impl Display for BusinessFunction { @@ -20,7 +20,7 @@ impl Display for BusinessFunction { use BusinessFunction::*; let s = match self { #[cfg(feature = "business")] - &BusinessDayCount => "business_day_count", + &BusinessDayCount { .. } => "business_day_count", }; write!(f, "{s}") } @@ -30,16 +30,16 @@ impl From for SpecialEq> { use BusinessFunction::*; match func { #[cfg(feature = "business")] - BusinessDayCount => { - map_as_slice!(business_day_count) + BusinessDayCount { week_mask } => { + map_as_slice!(business_day_count, week_mask) }, } } } #[cfg(feature = "business")] -pub(super) fn business_day_count(s: &[Series]) -> PolarsResult { +pub(super) fn business_day_count(s: &[Series], week_mask: [bool; 7]) -> PolarsResult { let start = &s[0]; let end = &s[1]; - polars_ops::prelude::business_day_count(start, end) + polars_ops::prelude::business_day_count(start, end, week_mask) } diff --git a/crates/polars-plan/src/dsl/functions/business.rs b/crates/polars-plan/src/dsl/functions/business.rs index 4bfdcc0b20cc..0a0210ced57f 100644 --- a/crates/polars-plan/src/dsl/functions/business.rs +++ b/crates/polars-plan/src/dsl/functions/business.rs @@ -1,12 +1,12 @@ use super::*; #[cfg(feature = "dtype-date")] -pub fn business_day_count(start: Expr, end: Expr) -> Expr { +pub fn business_day_count(start: Expr, end: Expr, week_mask: [bool; 7]) -> Expr { let input = vec![start, end]; Expr::Function { input, - function: FunctionExpr::Business(BusinessFunction::BusinessDayCount {}), + function: FunctionExpr::Business(BusinessFunction::BusinessDayCount { week_mask }), options: FunctionOptions { allow_rename: true, ..Default::default() diff --git a/py-polars/polars/functions/business.py b/py-polars/polars/functions/business.py index ae5791fde2a6..f65292f10ef4 100644 --- a/py-polars/polars/functions/business.py +++ b/py-polars/polars/functions/business.py @@ -1,7 +1,7 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable from polars._utils.parse_expr_input import parse_as_expression from polars._utils.wrap import wrap_expr @@ -13,25 +13,61 @@ from datetime import date from polars import Expr - from polars.type_aliases import IntoExprColumn + from polars.type_aliases import DayOfWeek, IntoExprColumn + +DAY_NAMES = ( + "Mon", + "Tue", + "Wed", + "Thu", + "Fri", + "Sat", + "Sun", +) + + +def _make_week_mask( + weekend: Iterable[str] | None, +) -> tuple[bool, ...]: + if weekend is None: + return tuple([True] * 7) + if isinstance(weekend, str): + weekend_set = {weekend} + else: + weekend_set = set(weekend) + for day in weekend_set: + if day not in DAY_NAMES: + msg = f"Expected one of {DAY_NAMES}, got: {day}" + raise ValueError(msg) + return tuple( + [ + False if v in weekend else True # noqa: SIM211 + for v in DAY_NAMES + ] + ) def business_day_count( start: date | IntoExprColumn, end: date | IntoExprColumn, + weekend: DayOfWeek | Iterable[DayOfWeek] | None = ("Sat", "Sun"), ) -> Expr: """ Count the number of business days between `start` and `end` (not including `end`). - By default, Saturday and Sunday are excluded. The ability to - customise week mask and holidays is not yet implemented. - Parameters ---------- start Start dates. end End dates. + weekend + Which days of the week to exclude. The default is `('Sat', 'Sun')`, but you + can also pass, for example, `weekend=('Fri', 'Sat')`, `weekend='Sun'`, + or `weekend=None`. + + Allowed values in the tuple are 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', + and 'Sun'. Returns ------- @@ -62,7 +98,24 @@ def business_day_count( Note how the two "count" columns differ due to the weekend (2020-01-04 - 2020-01-05) not being counted by `business_day_count`. + + You can pass a custom weekend - for example, if you only take Sunday off: + + >>> df.with_columns( + ... total_day_count=(pl.col("end") - pl.col("start")).dt.total_days(), + ... business_day_count=pl.business_day_count("start", "end", weekend="Sun"), + ... ) + shape: (2, 4) + ┌────────────┬────────────┬─────────────────┬────────────────────┐ + │ start ┆ end ┆ total_day_count ┆ business_day_count │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ date ┆ date ┆ i64 ┆ i32 │ + ╞════════════╪════════════╪═════════════════╪════════════════════╡ + │ 2020-01-01 ┆ 2020-01-02 ┆ 1 ┆ 1 │ + │ 2020-01-02 ┆ 2020-01-10 ┆ 8 ┆ 7 │ + └────────────┴────────────┴─────────────────┴────────────────────┘ """ start_pyexpr = parse_as_expression(start) end_pyexpr = parse_as_expression(end) - return wrap_expr(plr.business_day_count(start_pyexpr, end_pyexpr)) + week_mask = _make_week_mask(weekend) + return wrap_expr(plr.business_day_count(start_pyexpr, end_pyexpr, week_mask)) diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index 2443ce4f64ad..bdc9e741e71b 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -158,6 +158,15 @@ "horizontal", "align", ] +DayOfWeek = Literal[ + "Mon", + "Tue", + "Wed", + "Thu", + "Fri", + "Sat", + "Sun", +] EpochTimeUnit = Literal["ns", "us", "ms", "s", "d"] Orientation: TypeAlias = Literal["col", "row"] SearchSortedSide: TypeAlias = Literal["any", "left", "right"] diff --git a/py-polars/src/functions/business.rs b/py-polars/src/functions/business.rs index 246f902b895a..0ca6ec058d4a 100644 --- a/py-polars/src/functions/business.rs +++ b/py-polars/src/functions/business.rs @@ -4,8 +4,8 @@ use pyo3::prelude::*; use crate::PyExpr; #[pyfunction] -pub fn business_day_count(start: PyExpr, end: PyExpr) -> PyExpr { +pub fn business_day_count(start: PyExpr, end: PyExpr, week_mask: [bool; 7]) -> PyExpr { let start = start.inner; let end = end.inner; - dsl::business_day_count(start, end).into() + dsl::business_day_count(start, end, week_mask).into() } diff --git a/py-polars/tests/parametric/time_series/test_business_day_count.py b/py-polars/tests/parametric/time_series/test_business_day_count.py index 0cb1bf95df33..04a2553e7aaf 100644 --- a/py-polars/tests/parametric/time_series/test_business_day_count.py +++ b/py-polars/tests/parametric/time_series/test_business_day_count.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime as dt +from typing import TYPE_CHECKING import hypothesis.strategies as st import numpy as np @@ -8,22 +9,33 @@ import polars as pl from polars._utils.various import parse_version +from polars.functions.business import _make_week_mask + +if TYPE_CHECKING: + from polars.type_aliases import DayOfWeek @given( start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), end=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + weekend=st.lists( + st.sampled_from(["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]), + min_size=0, + max_size=6, + unique=True, + ), ) def test_against_np_busday_count( start: dt.date, end: dt.date, + weekend: list[DayOfWeek], ) -> None: result = ( pl.DataFrame({"start": [start], "end": [end]}) - .select(n=pl.business_day_count("start", "end"))["n"] + .select(n=pl.business_day_count("start", "end", weekend=weekend))["n"] .item() ) - expected = np.busday_count(start, end) + expected = np.busday_count(start, end, weekmask=_make_week_mask(weekend)) if start > end and parse_version(np.__version__) < parse_version("1.25"): # Bug in old versions of numpy reject() diff --git a/py-polars/tests/unit/functions/business/test_business_day_count.py b/py-polars/tests/unit/functions/business/test_business_day_count.py index 74befbd3268b..6feced25e2bb 100644 --- a/py-polars/tests/unit/functions/business/test_business_day_count.py +++ b/py-polars/tests/unit/functions/business/test_business_day_count.py @@ -1,5 +1,7 @@ from datetime import date +import pytest + import polars as pl from polars.testing import assert_series_equal @@ -50,6 +52,40 @@ def test_business_day_count() -> None: assert_series_equal(result, expected) +def test_business_day_count_w_weekend() -> None: + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "end": [date(2020, 1, 2), date(2020, 1, 10)], + } + ) + result = df.select( + business_day_count=pl.business_day_count("start", "end", weekend="Sun"), + )["business_day_count"] + expected = pl.Series("business_day_count", [1, 7], pl.Int32) + assert_series_equal(result, expected) + result = df.select( + business_day_count=pl.business_day_count( + "start", "end", weekend=("Thu", "Fri", "Sat") + ), + )["business_day_count"] + expected = pl.Series("business_day_count", [1, 4], pl.Int32) + assert_series_equal(result, expected) + result = df.select( + business_day_count=pl.business_day_count("start", "end", weekend=None), + )["business_day_count"] + expected = pl.Series("business_day_count", [1, 8], pl.Int32) + assert_series_equal(result, expected) + + +def test_business_day_count_w_weekend_invalid() -> None: + msg = r"Expected one of \('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'\), got: cabbage" + with pytest.raises(ValueError, match=msg): + pl.business_day_count("start", "end", weekend="cabbage") # type: ignore[arg-type] + with pytest.raises(ValueError, match=msg): + pl.business_day_count("start", "end", weekend=("Sat", "cabbage")) # type: ignore[arg-type] + + def test_business_day_count_schema() -> None: lf = pl.LazyFrame( { From 81cd34a2760f47612ac40ed790bc9ac94e673dcb Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 9 Apr 2024 08:47:41 +0100 Subject: [PATCH 2/4] cache day names --- py-polars/polars/functions/business.py | 28 +++++++++++++++----------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/py-polars/polars/functions/business.py b/py-polars/polars/functions/business.py index f65292f10ef4..c4d8d73ab24d 100644 --- a/py-polars/polars/functions/business.py +++ b/py-polars/polars/functions/business.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import functools from typing import TYPE_CHECKING, Iterable from polars._utils.parse_expr_input import parse_as_expression @@ -15,15 +16,18 @@ from polars import Expr from polars.type_aliases import DayOfWeek, IntoExprColumn -DAY_NAMES = ( - "Mon", - "Tue", - "Wed", - "Thu", - "Fri", - "Sat", - "Sun", -) + +@functools.lru_cache +def _day_names() -> tuple[str, ...]: + return ( + "Mon", + "Tue", + "Wed", + "Thu", + "Fri", + "Sat", + "Sun", + ) def _make_week_mask( @@ -36,13 +40,13 @@ def _make_week_mask( else: weekend_set = set(weekend) for day in weekend_set: - if day not in DAY_NAMES: - msg = f"Expected one of {DAY_NAMES}, got: {day}" + if day not in _day_names(): + msg = f"Expected one of {_day_names()}, got: {day}" raise ValueError(msg) return tuple( [ False if v in weekend else True # noqa: SIM211 - for v in DAY_NAMES + for v in _day_names() ] ) From fe61fc28fd06a379fc7cacd4e37fdefde52f0032 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 9 Apr 2024 09:12:37 +0100 Subject: [PATCH 3/4] extra test --- py-polars/polars/functions/business.py | 6 +++--- .../unit/functions/business/test_business_day_count.py | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/py-polars/polars/functions/business.py b/py-polars/polars/functions/business.py index c4d8d73ab24d..df23d9700eb6 100644 --- a/py-polars/polars/functions/business.py +++ b/py-polars/polars/functions/business.py @@ -36,10 +36,10 @@ def _make_week_mask( if weekend is None: return tuple([True] * 7) if isinstance(weekend, str): - weekend_set = {weekend} + weekend = {weekend} else: - weekend_set = set(weekend) - for day in weekend_set: + weekend = set(weekend) + for day in weekend: if day not in _day_names(): msg = f"Expected one of {_day_names()}, got: {day}" raise ValueError(msg) diff --git a/py-polars/tests/unit/functions/business/test_business_day_count.py b/py-polars/tests/unit/functions/business/test_business_day_count.py index 6feced25e2bb..bf08f4ac98fb 100644 --- a/py-polars/tests/unit/functions/business/test_business_day_count.py +++ b/py-polars/tests/unit/functions/business/test_business_day_count.py @@ -64,6 +64,13 @@ def test_business_day_count_w_weekend() -> None: )["business_day_count"] expected = pl.Series("business_day_count", [1, 7], pl.Int32) assert_series_equal(result, expected) + + result = df.select( + business_day_count=pl.business_day_count("start", "end", weekend=("Sun",)), + )["business_day_count"] + expected = pl.Series("business_day_count", [1, 7], pl.Int32) + assert_series_equal(result, expected) + result = df.select( business_day_count=pl.business_day_count( "start", "end", weekend=("Thu", "Fri", "Sat") From 763537920327166c0aef230ed99b02eecce0ecc7 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 9 Apr 2024 20:48:44 +0100 Subject: [PATCH 4/4] keep it simple --- crates/polars-ops/src/series/ops/business.rs | 3 ++ py-polars/polars/functions/business.py | 54 +++---------------- py-polars/polars/type_aliases.py | 9 ---- .../time_series/test_business_day_count.py | 23 ++++---- .../business/test_business_day_count.py | 38 +++++++------ 5 files changed, 38 insertions(+), 89 deletions(-) diff --git a/crates/polars-ops/src/series/ops/business.rs b/crates/polars-ops/src/series/ops/business.rs index 3b975ec778ec..115ccf8ae389 100644 --- a/crates/polars-ops/src/series/ops/business.rs +++ b/crates/polars-ops/src/series/ops/business.rs @@ -12,6 +12,9 @@ pub fn business_day_count( end: &Series, week_mask: [bool; 7], ) -> PolarsResult { + if !week_mask.iter().any(|&x| x) { + polars_bail!(ComputeError:"`week_mask` must have at least one business day"); + } let start_dates = start.date()?; let end_dates = end.date()?; let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32; diff --git a/py-polars/polars/functions/business.py b/py-polars/polars/functions/business.py index df23d9700eb6..125bda15113e 100644 --- a/py-polars/polars/functions/business.py +++ b/py-polars/polars/functions/business.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -import functools from typing import TYPE_CHECKING, Iterable from polars._utils.parse_expr_input import parse_as_expression @@ -14,47 +13,13 @@ from datetime import date from polars import Expr - from polars.type_aliases import DayOfWeek, IntoExprColumn - - -@functools.lru_cache -def _day_names() -> tuple[str, ...]: - return ( - "Mon", - "Tue", - "Wed", - "Thu", - "Fri", - "Sat", - "Sun", - ) - - -def _make_week_mask( - weekend: Iterable[str] | None, -) -> tuple[bool, ...]: - if weekend is None: - return tuple([True] * 7) - if isinstance(weekend, str): - weekend = {weekend} - else: - weekend = set(weekend) - for day in weekend: - if day not in _day_names(): - msg = f"Expected one of {_day_names()}, got: {day}" - raise ValueError(msg) - return tuple( - [ - False if v in weekend else True # noqa: SIM211 - for v in _day_names() - ] - ) + from polars.type_aliases import IntoExprColumn def business_day_count( start: date | IntoExprColumn, end: date | IntoExprColumn, - weekend: DayOfWeek | Iterable[DayOfWeek] | None = ("Sat", "Sun"), + week_mask: Iterable[bool] = (True, True, True, True, True, False, False), ) -> Expr: """ Count the number of business days between `start` and `end` (not including `end`). @@ -65,13 +30,10 @@ def business_day_count( Start dates. end End dates. - weekend - Which days of the week to exclude. The default is `('Sat', 'Sun')`, but you - can also pass, for example, `weekend=('Fri', 'Sat')`, `weekend='Sun'`, - or `weekend=None`. - - Allowed values in the tuple are 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', - and 'Sun'. + week_mask + Which days of the week to count. The default is Monday to Friday. + If you wanted to count only Monday to Thursday, you would pass + `(True, True, True, True, False, False, False)`. Returns ------- @@ -105,9 +67,10 @@ def business_day_count( You can pass a custom weekend - for example, if you only take Sunday off: + >>> week_mask = (True, True, True, True, True, True, False) >>> df.with_columns( ... total_day_count=(pl.col("end") - pl.col("start")).dt.total_days(), - ... business_day_count=pl.business_day_count("start", "end", weekend="Sun"), + ... business_day_count=pl.business_day_count("start", "end", week_mask), ... ) shape: (2, 4) ┌────────────┬────────────┬─────────────────┬────────────────────┐ @@ -121,5 +84,4 @@ def business_day_count( """ start_pyexpr = parse_as_expression(start) end_pyexpr = parse_as_expression(end) - week_mask = _make_week_mask(weekend) return wrap_expr(plr.business_day_count(start_pyexpr, end_pyexpr, week_mask)) diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index bdc9e741e71b..2443ce4f64ad 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -158,15 +158,6 @@ "horizontal", "align", ] -DayOfWeek = Literal[ - "Mon", - "Tue", - "Wed", - "Thu", - "Fri", - "Sat", - "Sun", -] EpochTimeUnit = Literal["ns", "us", "ms", "s", "d"] Orientation: TypeAlias = Literal["col", "row"] SearchSortedSide: TypeAlias = Literal["any", "left", "right"] diff --git a/py-polars/tests/parametric/time_series/test_business_day_count.py b/py-polars/tests/parametric/time_series/test_business_day_count.py index 04a2553e7aaf..7d9d61fbc5cb 100644 --- a/py-polars/tests/parametric/time_series/test_business_day_count.py +++ b/py-polars/tests/parametric/time_series/test_business_day_count.py @@ -1,41 +1,36 @@ from __future__ import annotations import datetime as dt -from typing import TYPE_CHECKING import hypothesis.strategies as st import numpy as np -from hypothesis import given, reject +from hypothesis import assume, given, reject import polars as pl from polars._utils.various import parse_version -from polars.functions.business import _make_week_mask - -if TYPE_CHECKING: - from polars.type_aliases import DayOfWeek @given( start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), end=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), - weekend=st.lists( - st.sampled_from(["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]), - min_size=0, - max_size=6, - unique=True, + week_mask=st.lists( + st.sampled_from([True, False]), + min_size=7, + max_size=7, ), ) def test_against_np_busday_count( start: dt.date, end: dt.date, - weekend: list[DayOfWeek], + week_mask: tuple[bool, ...], ) -> None: + assume(any(week_mask)) result = ( pl.DataFrame({"start": [start], "end": [end]}) - .select(n=pl.business_day_count("start", "end", weekend=weekend))["n"] + .select(n=pl.business_day_count("start", "end", week_mask=week_mask))["n"] .item() ) - expected = np.busday_count(start, end, weekmask=_make_week_mask(weekend)) + expected = np.busday_count(start, end, weekmask=week_mask) if start > end and parse_version(np.__version__) < parse_version("1.25"): # Bug in old versions of numpy reject() diff --git a/py-polars/tests/unit/functions/business/test_business_day_count.py b/py-polars/tests/unit/functions/business/test_business_day_count.py index bf08f4ac98fb..13a1a05dbb7b 100644 --- a/py-polars/tests/unit/functions/business/test_business_day_count.py +++ b/py-polars/tests/unit/functions/business/test_business_day_count.py @@ -52,7 +52,7 @@ def test_business_day_count() -> None: assert_series_equal(result, expected) -def test_business_day_count_w_weekend() -> None: +def test_business_day_count_w_week_mask() -> None: df = pl.DataFrame( { "start": [date(2020, 1, 1), date(2020, 1, 2)], @@ -60,37 +60,35 @@ def test_business_day_count_w_weekend() -> None: } ) result = df.select( - business_day_count=pl.business_day_count("start", "end", weekend="Sun"), - )["business_day_count"] - expected = pl.Series("business_day_count", [1, 7], pl.Int32) - assert_series_equal(result, expected) - - result = df.select( - business_day_count=pl.business_day_count("start", "end", weekend=("Sun",)), + business_day_count=pl.business_day_count( + "start", "end", week_mask=(True, True, True, True, True, True, False) + ), )["business_day_count"] expected = pl.Series("business_day_count", [1, 7], pl.Int32) assert_series_equal(result, expected) result = df.select( business_day_count=pl.business_day_count( - "start", "end", weekend=("Thu", "Fri", "Sat") + "start", "end", week_mask=(True, True, True, False, False, False, True) ), )["business_day_count"] expected = pl.Series("business_day_count", [1, 4], pl.Int32) assert_series_equal(result, expected) - result = df.select( - business_day_count=pl.business_day_count("start", "end", weekend=None), - )["business_day_count"] - expected = pl.Series("business_day_count", [1, 8], pl.Int32) - assert_series_equal(result, expected) -def test_business_day_count_w_weekend_invalid() -> None: - msg = r"Expected one of \('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'\), got: cabbage" - with pytest.raises(ValueError, match=msg): - pl.business_day_count("start", "end", weekend="cabbage") # type: ignore[arg-type] - with pytest.raises(ValueError, match=msg): - pl.business_day_count("start", "end", weekend=("Sat", "cabbage")) # type: ignore[arg-type] +def test_business_day_count_w_week_mask_invalid() -> None: + with pytest.raises(ValueError, match=r"expected a sequence of length 7 \(got 2\)"): + pl.business_day_count("start", "end", week_mask=(False, 0)) # type: ignore[arg-type] + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "end": [date(2020, 1, 2), date(2020, 1, 10)], + } + ) + with pytest.raises( + pl.ComputeError, match="`week_mask` must have at least one business day" + ): + df.select(pl.business_day_count("start", "end", week_mask=[False] * 7)) def test_business_day_count_schema() -> None: