Skip to content

Commit

Permalink
keep it simple
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Apr 9, 2024
1 parent fe61fc2 commit 7635379
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 89 deletions.
3 changes: 3 additions & 0 deletions crates/polars-ops/src/series/ops/business.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ pub fn business_day_count(
end: &Series,
week_mask: [bool; 7],
) -> PolarsResult<Series> {
if !week_mask.iter().any(|&x| x) {
polars_bail!(ComputeError:"`week_mask` must have at least one business day");
}
let start_dates = start.date()?;
let end_dates = end.date()?;
let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;
Expand Down
54 changes: 8 additions & 46 deletions py-polars/polars/functions/business.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import contextlib
import functools
from typing import TYPE_CHECKING, Iterable

from polars._utils.parse_expr_input import parse_as_expression
Expand All @@ -14,47 +13,13 @@
from datetime import date

from polars import Expr
from polars.type_aliases import DayOfWeek, IntoExprColumn


@functools.lru_cache
def _day_names() -> tuple[str, ...]:
    """
    Return the abbreviated weekday names, ordered Monday through Sunday.

    Returns
    -------
    tuple of str
        The seven names `"Mon"`, `"Tue"`, `"Wed"`, `"Thu"`, `"Fri"`, `"Sat"`
        and `"Sun"`, in that order.
    """
    # Cached, so every call hands back the same tuple object.
    return ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")


def _make_week_mask(
    weekend: Iterable[str] | None,
) -> tuple[bool, ...]:
    """
    Convert a weekend specification into a seven-element business-day mask.

    Parameters
    ----------
    weekend
        Day name(s) to exclude, each one of the names returned by
        `_day_names()`. A single string is treated as one day; `None` means
        that no day is excluded.

    Returns
    -------
    tuple of bool
        One flag per entry of `_day_names()`, in the same order: `True` for a
        day that is counted, `False` for a day listed in `weekend`.

    Raises
    ------
    ValueError
        If `weekend` contains a name that `_day_names()` does not list.
    """
    day_names = _day_names()
    if weekend is None:
        return (True,) * len(day_names)
    # A bare string is a single day name, not an iterable of characters.
    excluded = {weekend} if isinstance(weekend, str) else set(weekend)
    for day in excluded:
        if day not in day_names:
            msg = f"Expected one of {day_names}, got: {day}"
            raise ValueError(msg)
    return tuple(name not in excluded for name in day_names)
from polars.type_aliases import IntoExprColumn


def business_day_count(
start: date | IntoExprColumn,
end: date | IntoExprColumn,
weekend: DayOfWeek | Iterable[DayOfWeek] | None = ("Sat", "Sun"),
week_mask: Iterable[bool] = (True, True, True, True, True, False, False),
) -> Expr:
"""
Count the number of business days between `start` and `end` (not including `end`).
Expand All @@ -65,13 +30,10 @@ def business_day_count(
Start dates.
end
End dates.
weekend
Which days of the week to exclude. The default is `('Sat', 'Sun')`, but you
can also pass, for example, `weekend=('Fri', 'Sat')`, `weekend='Sun'`,
or `weekend=None`.
Allowed values in the tuple are 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat',
and 'Sun'.
week_mask
Which days of the week to count, given as seven booleans ordered from
Monday to Sunday. The default counts Monday to Friday.
To count only Monday to Thursday, pass
`(True, True, True, True, False, False, False)`.
Returns
-------
Expand Down Expand Up @@ -105,9 +67,10 @@ def business_day_count(
You can pass a custom week mask - for example, if you only take Sunday off:
>>> week_mask = (True, True, True, True, True, True, False)
>>> df.with_columns(
... total_day_count=(pl.col("end") - pl.col("start")).dt.total_days(),
... business_day_count=pl.business_day_count("start", "end", weekend="Sun"),
... business_day_count=pl.business_day_count("start", "end", week_mask),
... )
shape: (2, 4)
┌────────────┬────────────┬─────────────────┬────────────────────┐
Expand All @@ -121,5 +84,4 @@ def business_day_count(
"""
start_pyexpr = parse_as_expression(start)
end_pyexpr = parse_as_expression(end)
week_mask = _make_week_mask(weekend)
return wrap_expr(plr.business_day_count(start_pyexpr, end_pyexpr, week_mask))
9 changes: 0 additions & 9 deletions py-polars/polars/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,6 @@
"horizontal",
"align",
]
DayOfWeek = Literal[
"Mon",
"Tue",
"Wed",
"Thu",
"Fri",
"Sat",
"Sun",
]
EpochTimeUnit = Literal["ns", "us", "ms", "s", "d"]
Orientation: TypeAlias = Literal["col", "row"]
SearchSortedSide: TypeAlias = Literal["any", "left", "right"]
Expand Down
23 changes: 9 additions & 14 deletions py-polars/tests/parametric/time_series/test_business_day_count.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,36 @@
from __future__ import annotations

import datetime as dt
from typing import TYPE_CHECKING

import hypothesis.strategies as st
import numpy as np
from hypothesis import given, reject
from hypothesis import assume, given, reject

import polars as pl
from polars._utils.various import parse_version
from polars.functions.business import _make_week_mask

if TYPE_CHECKING:
from polars.type_aliases import DayOfWeek


@given(
start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)),
end=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)),
weekend=st.lists(
st.sampled_from(["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]),
min_size=0,
max_size=6,
unique=True,
week_mask=st.lists(
st.sampled_from([True, False]),
min_size=7,
max_size=7,
),
)
def test_against_np_busday_count(
start: dt.date,
end: dt.date,
weekend: list[DayOfWeek],
week_mask: tuple[bool, ...],
) -> None:
assume(any(week_mask))
result = (
pl.DataFrame({"start": [start], "end": [end]})
.select(n=pl.business_day_count("start", "end", weekend=weekend))["n"]
.select(n=pl.business_day_count("start", "end", week_mask=week_mask))["n"]
.item()
)
expected = np.busday_count(start, end, weekmask=_make_week_mask(weekend))
expected = np.busday_count(start, end, weekmask=week_mask)
if start > end and parse_version(np.__version__) < parse_version("1.25"):
# Bug in old versions of numpy
reject()
Expand Down
38 changes: 18 additions & 20 deletions py-polars/tests/unit/functions/business/test_business_day_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,45 +52,43 @@ def test_business_day_count() -> None:
assert_series_equal(result, expected)


def test_business_day_count_w_weekend() -> None:
def test_business_day_count_w_week_mask() -> None:
df = pl.DataFrame(
{
"start": [date(2020, 1, 1), date(2020, 1, 2)],
"end": [date(2020, 1, 2), date(2020, 1, 10)],
}
)
result = df.select(
business_day_count=pl.business_day_count("start", "end", weekend="Sun"),
)["business_day_count"]
expected = pl.Series("business_day_count", [1, 7], pl.Int32)
assert_series_equal(result, expected)

result = df.select(
business_day_count=pl.business_day_count("start", "end", weekend=("Sun",)),
business_day_count=pl.business_day_count(
"start", "end", week_mask=(True, True, True, True, True, True, False)
),
)["business_day_count"]
expected = pl.Series("business_day_count", [1, 7], pl.Int32)
assert_series_equal(result, expected)

result = df.select(
business_day_count=pl.business_day_count(
"start", "end", weekend=("Thu", "Fri", "Sat")
"start", "end", week_mask=(True, True, True, False, False, False, True)
),
)["business_day_count"]
expected = pl.Series("business_day_count", [1, 4], pl.Int32)
assert_series_equal(result, expected)
result = df.select(
business_day_count=pl.business_day_count("start", "end", weekend=None),
)["business_day_count"]
expected = pl.Series("business_day_count", [1, 8], pl.Int32)
assert_series_equal(result, expected)


def test_business_day_count_w_weekend_invalid() -> None:
msg = r"Expected one of \('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'\), got: cabbage"
with pytest.raises(ValueError, match=msg):
pl.business_day_count("start", "end", weekend="cabbage") # type: ignore[arg-type]
with pytest.raises(ValueError, match=msg):
pl.business_day_count("start", "end", weekend=("Sat", "cabbage")) # type: ignore[arg-type]
def test_business_day_count_w_week_mask_invalid() -> None:
with pytest.raises(ValueError, match=r"expected a sequence of length 7 \(got 2\)"):
pl.business_day_count("start", "end", week_mask=(False, 0)) # type: ignore[arg-type]
df = pl.DataFrame(
{
"start": [date(2020, 1, 1), date(2020, 1, 2)],
"end": [date(2020, 1, 2), date(2020, 1, 10)],
}
)
with pytest.raises(
pl.ComputeError, match="`week_mask` must have at least one business day"
):
df.select(pl.business_day_count("start", "end", week_mask=[False] * 7))


def test_business_day_count_schema() -> None:
Expand Down

0 comments on commit 7635379

Please sign in to comment.