
Refactor to remove try_cast spark function #163

Merged: 1 commit merged on Feb 10, 2025
26 changes: 12 additions & 14 deletions src/databricks/labs/dqx/col_functions.py
@@ -37,7 +37,7 @@ def is_not_null_and_not_empty(col_name: str, trim_strings: bool = False) -> Column:
column = F.col(col_name)
if trim_strings:
column = F.trim(column).alias(col_name)
condition = column.isNull() | (column.try_cast("string") == F.lit(""))
condition = column.isNull() | (column.cast("string").isNull() | (column.cast("string") == F.lit("")))
return make_condition(condition, f"Column {col_name} is null or empty", f"{col_name}_is_null_or_empty")


@@ -48,8 +48,8 @@ def is_not_empty(col_name: str) -> Column:
:return: Column object for condition
"""
column = F.col(col_name)
column = column.try_cast("string")
return make_condition((column == ""), f"Column {col_name} is empty", f"{col_name}_is_empty")
condition = column.cast("string") == F.lit("")
return make_condition(condition, f"Column {col_name} is empty", f"{col_name}_is_empty")


def is_not_null(col_name: str) -> Column:
@@ -77,7 +77,7 @@ def value_is_not_null_and_is_in_list(col_name: str, allowed: list) -> Column:
F.concat_ws(
"",
F.lit("Value "),
F.when(column.isNull(), F.lit("null")).otherwise(column.try_cast("string")),
F.when(column.isNull(), F.lit("null")).otherwise(column.cast("string")),
F.lit(" is not in the allowed list: ["),
F.concat_ws(", ", *allowed_cols),
F.lit("]"),
@@ -381,15 +381,15 @@ def is_valid_date(col_name: str, date_format: str | None = None) -> Column:
:param date_format: date format (e.g. 'yyyy-mm-dd')
:return: Column object for condition
"""
str_col = F.col(col_name)
date_col = str_col.try_cast("date") if date_format is None else F.try_to_timestamp(str_col, F.lit(date_format))
condition = F.when(str_col.isNull(), F.lit(None)).otherwise(date_col.isNull())
column = F.col(col_name)
date_col = F.try_to_timestamp(column) if date_format is None else F.try_to_timestamp(column, F.lit(date_format))
condition = F.when(column.isNull(), F.lit(None)).otherwise(date_col.isNull())
condition_str = "' is not a valid date"
if date_format is not None:
condition_str += f" with format '{date_format}'"
return make_condition(
condition,
F.concat_ws("", F.lit("Value '"), str_col, F.lit(condition_str)),
F.concat_ws("", F.lit("Value '"), column, F.lit(condition_str)),
f"{col_name}_is_not_valid_date",
)
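
A quick illustration (not from the PR) of why F.try_to_timestamp works here: it returns null for unparseable input instead of raising an error, so the null check doubles as the validity check. The DataFrame below is invented, and F.try_to_timestamp requires PySpark 3.5+:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("2024-01-01",), ("invalid_date",), (None,)], ["d"])

column = F.col("d")
date_col = F.try_to_timestamp(column)  # null when the value cannot be parsed
condition = F.when(column.isNull(), F.lit(None)).otherwise(date_col.isNull())
df.select(column, date_col.alias("parsed"), condition.alias("d_is_not_valid_date")).show()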

@@ -401,18 +401,16 @@ def is_valid_timestamp(col_name: str, timestamp_format: str | None = None) -> Column:
:param timestamp_format: timestamp format (e.g. 'yyyy-mm-dd HH:mm:ss')
:return: Column object for condition
"""
str_col = F.col(col_name)
column = F.col(col_name)
ts_col = (
str_col.try_cast("timestamp")
if timestamp_format is None
else F.try_to_timestamp(str_col, F.lit(timestamp_format))
F.try_to_timestamp(column) if timestamp_format is None else F.try_to_timestamp(column, F.lit(timestamp_format))
)
condition = F.when(str_col.isNull(), F.lit(None)).otherwise(ts_col.isNull())
condition = F.when(column.isNull(), F.lit(None)).otherwise(ts_col.isNull())
condition_str = "' is not a valid timestamp"
if timestamp_format is not None:
condition_str += f" with format '{timestamp_format}'"
return make_condition(
condition,
F.concat_ws("", F.lit("Value '"), str_col, F.lit(condition_str)),
F.concat_ws("", F.lit("Value '"), column, F.lit(condition_str)),
f"{col_name}_is_not_valid_timestamp",
)
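
The same idea with an explicit format, as a hedged sketch (names and format are invented; PySpark 3.5+). try_to_timestamp yields null when the value does not match the pattern, which is exactly what the condition tests:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("12/31/2025 00:00:00",), ("not a timestamp",)], ["t"])

column = F.col("t")
ts_col = F.try_to_timestamp(column, F.lit("MM/dd/yyyy HH:mm:ss"))  # null on format mismatch
condition = F.when(column.isNull(), F.lit(None)).otherwise(ts_col.isNull())
df.select(column, ts_col.alias("parsed"), condition.alias("t_is_not_valid_timestamp")).show()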
11 changes: 11 additions & 0 deletions tests/integration/conftest.py
@@ -39,6 +39,17 @@ def product_info():
return "dqx", __version__


@pytest.fixture
def set_utc_timezone():
"""
Set the timezone to UTC for the duration of the test to make sure Spark timestamps
are handled the same way regardless of the environment.
"""
os.environ["TZ"] = "UTC"
yield
os.environ.pop("TZ")


@pytest.fixture
def make_check_file_as_yaml(ws, make_random, make_directory):
def create(**kwargs):
3 changes: 1 addition & 2 deletions tests/integration/test_apply_checks.py
@@ -509,8 +509,7 @@ def test_apply_checks_by_metadata_with_func_defined_outside_framework(ws, spark):

def col_test_check_func(col_name: str) -> Column:
check_col = F.col(col_name)
check_col = check_col.try_cast("string")
condition = check_col.isNull() | (check_col == "") | (check_col == "null")
condition = check_col.isNull() | (check_col.cast("string").isNull() | (check_col.cast("string") == F.lit("")))
return make_condition(condition, "new check failed", f"{col_name}_is_null_or_empty")


@@ -220,7 +220,7 @@ def test_is_col_older_than_n_days_cur(spark):
assert_df_equality(actual, expected, ignore_nullable=True)


def test_col_not_less_than(spark):
def test_col_not_less_than(spark, set_utc_timezone):
schema_num = "a: int, b: date, c: timestamp"
test_df = spark.createDataFrame(
[
@@ -254,7 +254,7 @@ def test_col_not_less_than(spark):
assert_df_equality(actual, expected, ignore_nullable=True)


def test_col_not_greater_than(spark):
def test_col_not_greater_than(spark, set_utc_timezone):
schema_num = "a: int, b: date, c: timestamp"
test_df = spark.createDataFrame(
[
@@ -288,7 +288,7 @@ def test_col_not_greater_than(spark):
assert_df_equality(actual, expected, ignore_nullable=True)


def test_col_is_in_range(spark):
def test_col_is_in_range(spark, set_utc_timezone):
schema_num = "a: int, b: date, c: timestamp"
test_df = spark.createDataFrame(
[
@@ -334,7 +334,7 @@ def test_col_is_in_range(spark):
assert_df_equality(actual, expected, ignore_nullable=True)


def test_col_is_not_in_range(spark):
def test_col_is_not_in_range(spark, set_utc_timezone):
schema_num = "a: int, b: date, c: timestamp"
test_df = spark.createDataFrame(
[
@@ -486,7 +486,7 @@ def test_col_is_not_null_and_not_empty_array(spark):
assert_df_equality(actual, expected, ignore_nullable=True)


def test_col_is_valid_date(spark):
def test_col_is_valid_date(spark, set_utc_timezone):
schema_array = "a: string, b: string, c: string, d: string"
data = [
["2024-01-01", "12/31/2025", "invalid_date", None],
@@ -526,7 +526,7 @@ def test_col_is_valid_date(spark):
assert_df_equality(actual, expected, ignore_nullable=True)


def test_col_is_valid_timestamp(spark):
def test_col_is_valid_timestamp(spark, set_utc_timezone):
schema_array = "a: string, b: string, c: string, d: string, e: string"
data = [
["2024-01-01 00:00:00", "12/31/2025 00:00:00", "invalid_timestamp", None, "2025-01-31T00:00:00"],