diff --git a/dataframely/columns/datetime.py b/dataframely/columns/datetime.py index cf71f83..ca09d4f 100644 --- a/dataframely/columns/datetime.py +++ b/dataframely/columns/datetime.py @@ -9,7 +9,12 @@ import polars as pl -from dataframely._compat import pa, sa, sa_mssql, sa_TypeEngine +from dataframely._compat import ( + pa, + sa, + sa_mssql, + sa_TypeEngine, +) from dataframely._polars import ( EPOCH_DATETIME, date_matches_resolution, @@ -265,6 +270,7 @@ def __init__( max: dt.datetime | None = None, max_exclusive: dt.datetime | None = None, resolution: str | None = None, + time_zone: str | dt.tzinfo | None = None, check: Callable[[pl.Expr], pl.Expr] | None = None, alias: str | None = None, metadata: dict[str, Any] | None = None, @@ -284,6 +290,9 @@ def __init__( the formatting language used by :mod:`polars` datetime ``round`` method. For example, a value ``1h`` expects all datetimes to be full hours. Note that this setting does *not* affect the storage resolution. + time_zone: The time zone that datetimes in the column must have. The time + zone must use a valid IANA time zone name identifier e.x. ``Etc/UTC`` or + ``America/New_York``. check: A custom check to run for this column. Must return a non-aggregated boolean expression. alias: An overwrite for this column's name which allows for using a column @@ -317,10 +326,11 @@ def __init__( metadata=metadata, ) self.resolution = resolution + self.time_zone = time_zone @property def dtype(self) -> pl.DataType: - return pl.Datetime() + return pl.Datetime(time_zone=self.time_zone) def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: result = super().validation_rules(expr) @@ -329,16 +339,22 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: return result def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: + timezone_enabled = self.time_zone is not None match dialect.name: case "mssql": # sa.DateTime wrongly maps to DATETIME - return sa_mssql.DATETIME2(6) + return sa_mssql.DATETIME2(6, timezone=timezone_enabled) case _: - return sa.DateTime() + return sa.DateTime(timezone=timezone_enabled) @property def pyarrow_dtype(self) -> pa.DataType: - return pa.timestamp("us") + time_zone = ( + self.time_zone.tzname(None) + if isinstance(self.time_zone, dt.timezone) + else self.time_zone + ) + return pa.timestamp("us", time_zone) def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_datetime( @@ -354,6 +370,7 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: allow_null_response=True, ), resolution=self.resolution, + time_zone=self.time_zone, null_probability=self._null_probability, ) diff --git a/dataframely/random.py b/dataframely/random.py index 076a99a..ae955d6 100644 --- a/dataframely/random.py +++ b/dataframely/random.py @@ -293,6 +293,7 @@ def sample_datetime( min: dt.datetime, max: dt.datetime | None, resolution: str | None = None, + time_zone: str | dt.tzinfo | None = None, null_probability: float = 0.0, ) -> pl.Series: """Sample a list of datetimes in the provided range. @@ -303,6 +304,9 @@ def sample_datetime( max: The maximum datetime to sample (exclusive). '10000-01-01' when ``None``. resolution: The resolution that datetimes in the column must have. This uses the formatting language used by :mod:`polars` datetime ``round`` method. + time_zone: The time zone that datetimes in the column must have. The time + zone must use a valid IANA time zone name identifier e.x. ``Etc/UTC`` or + ``America/New_York``. null_probability: The probability of an element being ``null``. Returns: @@ -329,7 +333,7 @@ def sample_datetime( ) # NOTE: polars tracks datetimes relative to epoch - _datetime_to_microseconds(EPOCH_DATETIME) - ).cast(pl.Datetime) + ).cast(pl.Datetime(time_zone=time_zone)) if resolution is not None: return result.dt.truncate(resolution) diff --git a/tests/column_types/test_datetime.py b/tests/column_types/test_datetime.py index d67a2fb..a37ff4a 100644 --- a/tests/column_types/test_datetime.py +++ b/tests/column_types/test_datetime.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import datetime as dt +import re from typing import Any import polars as pl @@ -10,6 +11,7 @@ import dataframely as dy from dataframely.columns import Column +from dataframely.exc import DtypeValidationError from dataframely.random import Generator from dataframely.testing import evaluate_rules, rules_from_exprs from dataframely.testing.factory import create_schema @@ -392,11 +394,42 @@ def test_validate_resolution( [ dy.Datetime( min=dt.datetime(2020, 1, 1), max=dt.datetime(2021, 1, 1), resolution="1h" - ) + ), + dy.Datetime(time_zone="Etc/UTC"), ], ) -def test_sample_resolution(column: dy.Column) -> None: +def test_sample(column: dy.Column) -> None: generator = Generator(seed=42) samples = column.sample(generator, n=10_000) schema = create_schema("test", {"a": column}) schema.validate(samples.to_frame("a")) + + +@pytest.mark.parametrize( + ("dtype", "column", "error"), + [ + ( + pl.Datetime(time_zone="America/New_York"), + dy.Datetime(time_zone="Etc/UTC"), + r"1 columns have an invalid dtype.*\n.*got dtype 'Datetime\(time_unit='us', time_zone='America/New_York'\)' but expected 'Datetime\(time_unit='us', time_zone='Etc/UTC'\)'", + ), + ( + pl.Datetime(time_zone="Etc/UTC"), + dy.Datetime(time_zone="Etc/UTC"), + None, + ), + ], +) +def test_dtype_time_zone_validation( + dtype: pl.DataType, + column: dy.Column, + error: str | None, +) -> None: + df = pl.DataFrame(schema={"a": dtype}) + schema = create_schema("test", {"a": column}) + if error is None: + schema.validate(df) + else: + with pytest.raises(DtypeValidationError) as exc: + schema.validate(df) + assert re.match(error, str(exc.value)) diff --git a/tests/columns/test_sql_schema.py b/tests/columns/test_sql_schema.py index ea2de2a..740a143 100644 --- a/tests/columns/test_sql_schema.py +++ b/tests/columns/test_sql_schema.py @@ -18,6 +18,7 @@ (dy.Bool(), "BIT"), (dy.Date(), "DATE"), (dy.Datetime(), "DATETIME2(6)"), + (dy.Datetime(time_zone="Etc/UTC"), "DATETIME2(6)"), (dy.Time(), "TIME(6)"), (dy.Duration(), "DATETIME2(6)"), (dy.Decimal(), "NUMERIC"), @@ -62,6 +63,7 @@ def test_mssql_datatype(column: Column, datatype: str) -> None: (dy.Bool(), "BOOLEAN"), (dy.Date(), "DATE"), (dy.Datetime(), "TIMESTAMP WITHOUT TIME ZONE"), + (dy.Datetime(time_zone="Etc/UTC"), "TIMESTAMP WITH TIME ZONE"), (dy.Time(), "TIME WITHOUT TIME ZONE"), (dy.Duration(), "INTERVAL"), (dy.Decimal(), "NUMERIC"),