Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: type check tests #49

Merged
merged 3 commits into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
- run: make venv
- run: make pre-commit
- run: make install
- run: venv/bin/python -m pytest tests && venv/bin/python -m pytest --doctest-modules polars_xdt
- run: make test

linux:
runs-on: ubuntu-latest
Expand Down
8 changes: 6 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ install-release: venv

pre-commit: venv
cargo fmt --all && cargo clippy --all-features
venv/bin/python -m ruff check .
venv/bin/python -m ruff check . --fix --exit-non-zero-on-fix
venv/bin/python -m ruff format
venv/bin/python -m mypy polars_xdt
venv/bin/python -m mypy polars_xdt tests

test: venv
venv/bin/python -m pytest tests
venv/bin/python -m pytest polars_xdt --doctest-modules

run: install
source venv/bin/activate && python run.py
Expand Down
1 change: 1 addition & 0 deletions bump_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# ruff: noqa
# type: ignore
import sys
import re
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# ruff: noqa
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
Expand Down
81 changes: 40 additions & 41 deletions polars_xdt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from __future__ import annotations

import polars as pl
from polars.utils.udfs import _get_shared_lib_location
import re
from datetime import date
import sys
from datetime import date
from typing import TYPE_CHECKING, Iterable, Literal, Protocol, Sequence, cast

import polars as pl
from polars.utils._parse_expr_input import parse_as_expression
from polars.utils._wrap import wrap_expr
from polars_xdt.ranges import date_range
from polars.utils.udfs import _get_shared_lib_location

from polars.type_aliases import PolarsDataType
from typing import Iterable, Literal, Protocol, Sequence, cast, TYPE_CHECKING
from polars_xdt.ranges import date_range

from ._internal import __version__ as __version__
from ._internal import __version__

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType

RollStrategy: TypeAlias = Literal["raise", "forward", "backward"]


Expand All @@ -36,22 +39,18 @@ def get_weekmask(weekend: Sequence[str]) -> list[bool]:
if weekend == ("Sat", "Sun"):
weekmask = [True, True, True, True, True, False, False]
else:
weekmask = [
False if reverse_mapping[i] in weekend else True
for i in range(1, 8)
]
weekmask = [reverse_mapping[i] not in weekend for i in range(1, 8)]
if sum(weekmask) == 0:
raise ValueError(
f"At least one day of the week must be a business day. Got weekend={weekend}"
)
msg = f"At least one day of the week must be a business day. Got weekend={weekend}"
raise ValueError(msg)
return weekmask


@pl.api.register_expr_namespace("xdt")
class ExprXDTNamespace:
"""eXtra stuff for DateTimes."""

def __init__(self, expr: pl.Expr):
def __init__(self, expr: pl.Expr) -> None:
self._expr = expr

def offset_by(
Expand All @@ -61,7 +60,7 @@ def offset_by(
weekend: Sequence[str] = ("Sat", "Sun"),
holidays: Sequence[date] | None = None,
roll: RollStrategy = "raise",
) -> xdtExpr:
) -> XDTExpr:
"""Offset this date by a relative time offset.

Parameters
Expand Down Expand Up @@ -169,7 +168,7 @@ def offset_by(
if not isinstance(by, pl.Expr):
by = pl.lit(by)
n = (by.str.extract(r"^(-?)") + by.str.extract(r"(\d+)bd")).cast(
pl.Int32
pl.Int32,
)
by = by.str.replace(r"(\d+bd)", "")
fastpath = False
Expand All @@ -178,7 +177,7 @@ def offset_by(
holidays_int = []
else:
holidays_int = sorted(
{(holiday - date(1970, 1, 1)).days for holiday in holidays}
{(holiday - date(1970, 1, 1)).days for holiday in holidays},
)
weekmask = get_weekmask(weekend)

Expand All @@ -194,22 +193,22 @@ def offset_by(
},
)
if fastpath:
return cast(xdtExpr, result)
return cast(xdtExpr, result.dt.offset_by(by))
return cast(XDTExpr, result)
return cast(XDTExpr, result.dt.offset_by(by))

def sub(
self,
end_dates: str | pl.Expr,
*,
weekend: Sequence[str] = ("Sat", "Sun"),
holidays: Sequence[date] | None = None,
) -> xdtExpr:
) -> XDTExpr:
weekmask = get_weekmask(weekend)
if not holidays:
holidays_int = []
else:
holidays_int = sorted(
{(holiday - date(1970, 1, 1)).days for holiday in holidays}
{(holiday - date(1970, 1, 1)).days for holiday in holidays},
)
if isinstance(end_dates, str):
end_dates = pl.col(end_dates)
Expand All @@ -223,7 +222,7 @@ def sub(
"holidays": holidays_int,
},
)
return cast(xdtExpr, result)
return cast(XDTExpr, result)

def is_workday(
self,
Expand Down Expand Up @@ -276,9 +275,9 @@ def is_workday(
holidays_int = []
else:
holidays_int = sorted(
{(holiday - date(1970, 1, 1)).days for holiday in holidays}
{(holiday - date(1970, 1, 1)).days for holiday in holidays},
)
result = self._expr.register_plugin(
return self._expr.register_plugin(
lib=lib,
symbol="is_workday",
is_elementwise=True,
Expand All @@ -288,14 +287,13 @@ def is_workday(
"holidays": holidays_int,
},
)
return result

def from_local_datetime(
self,
from_tz: str | Expr,
to_tz: str,
ambiguous: Ambiguous = "raise",
) -> xdtExpr:
) -> XDTExpr:
"""Converts from local datetime in given time zone to new timezone.

Parameters
Expand Down Expand Up @@ -364,12 +362,12 @@ def from_local_datetime(
"ambiguous": ambiguous,
},
)
return cast(xdtExpr, result)
return cast(XDTExpr, result)

def to_local_datetime(
self,
time_zone: str | Expr,
) -> xdtExpr:
) -> XDTExpr:
"""Convert to local datetime in given time zone.

Parameters
Expand Down Expand Up @@ -422,13 +420,13 @@ def to_local_datetime(
is_elementwise=True,
args=[time_zone],
)
return cast(xdtExpr, result)
return cast(XDTExpr, result)

def format_localized(
self,
format: str,
format: str, # noqa: A002
locale: str = "uk_UA",
) -> xdtExpr:
) -> XDTExpr:
"""Convert to local datetime in given time zone.

Parameters
Expand Down Expand Up @@ -476,12 +474,12 @@ def format_localized(
args=[],
kwargs={"format": format, "locale": locale},
)
return cast(xdtExpr, result)
return cast(XDTExpr, result)

def ceil(
self,
every: str | pl.Expr,
) -> xdtExpr:
) -> XDTExpr:
"""Find "ceiling" of datetime.

Parameters
Expand Down Expand Up @@ -541,21 +539,21 @@ def ceil(
.then(self._expr)
.otherwise(truncated.dt.offset_by(every))
)
return cast(xdtExpr, result)
return cast(XDTExpr, result)


class xdtExpr(pl.Expr):
class XDTExpr(pl.Expr):
@property
def xdt(self) -> ExprXDTNamespace:
return ExprXDTNamespace(self)


class xdtColumn(Protocol):
class XDTColumn(Protocol):
def __call__(
self,
name: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType],
*more_names: str | PolarsDataType,
) -> xdtExpr:
) -> XDTExpr:
...

def __getattr__(self, name: str) -> pl.Expr:
Expand All @@ -566,15 +564,15 @@ def xdt(self) -> ExprXDTNamespace:
...


col = cast(xdtColumn, pl.col)
col = cast(XDTColumn, pl.col)


def workday_count(
start: str | pl.Expr | date,
end: str | pl.Expr | date,
weekend: Sequence[str] = ("Sat", "Sun"),
holidays: Sequence[date] | None = None,
) -> xdtExpr:
) -> XDTExpr:
"""Count the number of workdays between two columns of dates.

Parameters
Expand Down Expand Up @@ -628,12 +626,13 @@ def workday_count(
end = pl.lit(end)

return end.xdt.sub(start, weekend=weekend, holidays=holidays).alias( # type: ignore[no-any-return, attr-defined]
"workday_count"
"workday_count",
)


__all__ = [
"col",
"date_range",
"workday_count",
"__version__",
]
17 changes: 9 additions & 8 deletions polars_xdt/ranges.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from __future__ import annotations

from typing import overload
import re
from typing import TYPE_CHECKING, Literal, Sequence, overload

from datetime import datetime, date, timedelta
from typing import Literal, Sequence
import polars as pl
from polars.type_aliases import IntoExprColumn, ClosedInterval, TimeUnit

mapping = {"Mon": 1, "Tue": 2, "Wed": 3, "Thu": 4, "Fri": 5, "Sat": 6, "Sun": 7}

if TYPE_CHECKING:
from datetime import date, datetime, timedelta

from polars.type_aliases import ClosedInterval, IntoExprColumn, TimeUnit


@overload
def date_range(
Expand Down Expand Up @@ -59,7 +61,7 @@ def date_range(
...


def date_range(
def date_range( # noqa: PLR0913
start: date | datetime | IntoExprColumn,
end: date | datetime | IntoExprColumn,
interval: str | timedelta = "1bd",
Expand Down Expand Up @@ -142,9 +144,8 @@ def date_range(
holidays = []

if not (isinstance(interval, str) and re.match(r"^-?\d+bd$", interval)):
raise ValueError(
"Only intervals of the form 'nbd' (where n is an integer) are supported."
)
msg = "Only intervals of the form 'nbd' (where n is an integer) are supported."
raise ValueError(msg)
interval = interval.replace("bd", "d")

expr = pl.date_range(
Expand Down
Loading