diff --git a/src/tidy_tools/model/convert.py b/src/tidy_tools/model/convert.py
index 2bcc003..108f70d 100644
--- a/src/tidy_tools/model/convert.py
+++ b/src/tidy_tools/model/convert.py
@@ -1,6 +1,7 @@
 import typing
 
 import attrs
+from loguru import logger
 from pyspark.sql import Column
 from pyspark.sql import functions as F
 from pyspark.sql import types as T
@@ -51,12 +52,28 @@ def convert_field(cls_field: attrs.Attribute, cls_field_exists: bool) -> Column:
     cls_field_type = get_pyspark_type(cls_field)
 
     match cls_field_type:
-        case T.DateType():  # TODO: make use of field.metadata.date_format
-            column = column.cast(cls_field_type)
-        case T.TimestampType():  # TODO: make use of field.metadata.date_format
-            column = column.cast(cls_field_type)
+        case T.DateType():
+            # Parse with the pattern from the field's metadata when one is given.
+            date_format = cls_field.metadata.get("format")
+            if date_format:
+                column = F.to_date(column, format=date_format)
+            else:
+                logger.warning(
+                    f"No `format` provided for {cls_field.name}. Please add `field(..., metadata={{'format': ...}})` and rerun."
+                )
+                column = column.cast(cls_field_type)
+        case T.TimestampType():
+            timestamp_format = cls_field.metadata.get("format")
+            if timestamp_format:
+                column = F.to_timestamp(column, format=timestamp_format)
+            else:
+                logger.warning(
+                    f"No `format` provided for {cls_field.name}. Please add `field(..., metadata={{'format': ...}})` and rerun."
+                )
+                column = column.cast(cls_field_type)
+        # Every other type is still cast directly to its declared type.
         case _:
             column = column.cast(cls_field_type)
 
     if cls_field.converter:
         column = cls_field.converter(column)