add helper functions and unit test cases
Ally Franken committed Jan 15, 2025
1 parent 735f900 commit d3fa481
Showing 2 changed files with 308 additions and 2 deletions.
120 changes: 118 additions & 2 deletions datacompy/spark/sql.py
@@ -36,6 +36,7 @@
try:
import pyspark.sql
from pyspark.sql import Window
from pyspark.sql import types as T
from pyspark.sql.functions import (
abs,
array,
@@ -49,6 +50,8 @@
trim,
upper,
when,
format_number,
row_number,
)
except ImportError:
LOG.warning(
@@ -1221,5 +1224,118 @@ def _is_comparable(type1: str, type2: str) -> bool:
or ({type1, type2} == {"string", "date"})
)

-def detailed_compare():
-    return

def detailed_compare(spark_session, prod_dataframe, release_dataframe, column_to_join=None, string2double_cols=None):
    """
    Run a detailed DataComPy comparison of two Spark DataFrames.

    :param spark_session: pyspark.sql.SparkSession
        an active Spark session
    :param prod_dataframe: Spark DataFrame
        dataset to be compared against
    :param release_dataframe: Spark DataFrame
        dataset to compare
    :param column_to_join: List of Strings, optional
        the column(s) on which the two datasets can be joined; an identifier that
        indicates which rows in both datasets should be compared. If empty or None,
        rows are compared in sorted order using a generated row number.
    :param string2double_cols: List of Strings, optional
        the columns that contain numeric values but are stored as string types
    :return: SparkSQLCompare object, from which a report can be generated
    """

    # Convert fields that contain numeric values stored as strings to numeric
    # types for comparison (truthiness check also guards against the default of None)
    if string2double_cols:
        prod_dataframe = handle_numeric_strings(prod_dataframe, string2double_cols)
        release_dataframe = handle_numeric_strings(release_dataframe, string2double_cols)

    if not column_to_join:
        # add a generated row number so the datasets can be compared by row
        # position instead of by a common column
        sorted_prod_df, sorted_release_df = sort_rows(prod_dataframe, release_dataframe)
        column_to_join = ['row']
else:
sorted_prod_df = prod_dataframe
sorted_release_df = release_dataframe

print("Compared by column(s): ", column_to_join)
if string2double_cols:
print('String column(s) cast to doubles for numeric comparison: ', string2double_cols)
compared_data = SparkSQLCompare(spark_session, sorted_prod_df, sorted_release_df,
join_columns=column_to_join, abs_tol=0.0001)
return compared_data
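
# Example usage (illustrative sketch: the DataFrame and column names below are
# hypothetical, but SparkSQLCompare.report() is the datacompy API exercised in
# this commit's tests):
#
#     compared = detailed_compare(
#         spark_session,
#         prod_df,
#         release_df,
#         column_to_join=["id"],
#         string2double_cols=["amount"],
#     )
#     print(compared.report())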


def handle_numeric_strings(df, field_list):
    # Cast string-typed columns that hold numeric values to DoubleType so they
    # can be compared numerically
    for this_col in field_list:
        df = df.withColumn(this_col, col(this_col).cast(T.DoubleType()))
    return df


def format_numeric_fields(df):
    # Round/truncate every numeric field to 5 decimal places so values compare
    # as fixed-format strings
    fixed_cols = []
    numeric_types = [
        "tinyint",
        "smallint",
        "int",
        "bigint",
        "float",
        "double",
        "decimal",
    ]

    for c in df.dtypes:
        # c is a (column_name, type_string) pair; parameterized types appear as
        # e.g. "decimal(10,2)", so compare on the base type name
        base_type = c[1].split("(")[0]
        # do not change non-numeric fields
        if base_type not in numeric_types:
            fixed_cols.append(col(c[0]))
        # round & truncate numeric fields (format_number returns a string)
        else:
            new_val = format_number(col(c[0]), 5).alias(c[0])
            fixed_cols.append(new_val)

    formatted_df = df.select(*fixed_cols)
    return formatted_df
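
# Note (illustrative, not part of the original code): format_number also
# inserts thousands separators, so large values gain commas in the result:
#
#     df = spark_session.createDataFrame([(1234.5,)], ["amount"])
#     format_numeric_fields(df).collect()
#     # -> [Row(amount='1,234.50000')]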


def sort_rows(prod_df, release_df):
    prod_cols = prod_df.columns
    release_cols = release_df.columns

    # Ensure every prod_df column also exists in release_df
    for x in prod_cols:
        if x not in release_cols:
            raise ValueError(f"{x} is present in prod_df but does not exist in release_df")

    if set(prod_cols) != set(release_cols):
        print('WARNING: There are columns present in Compare df that do not exist in Base df. The Base df columns will be used for row-wise sorting and may produce unanticipated report output if the extra fields are not null.')

    # Number the rows of each DataFrame after sorting on the shared columns.
    # Note: a Window without partitionBy pulls all rows into one partition,
    # which can be expensive on large datasets.
    w = Window.orderBy(*prod_cols)
    sorted_prod_df = prod_df.select('*', row_number().over(w).alias('row'))
    sorted_release_df = release_df.select('*', row_number().over(w).alias('row'))
    return sorted_prod_df, sorted_release_df


def sort_columns(prod_df, release_df):
    # Ensure every prod_df column also exists in release_df
    for x in prod_df.columns:
        if x not in release_df.columns:
            raise ValueError(f"{x} is present in prod_df but does not exist in release_df")
    # Sort both DataFrames on the prod_df columns (a list, so the sort-key
    # order is deterministic) to ensure a consistent row order
    sort_cols = prod_df.columns
    prod_df = prod_df.orderBy(*sort_cols)
    release_df = release_df.orderBy(*sort_cols)
    return prod_df, release_df


def convert_exponential_strings(base_df, compare_df):
    # Convert scientific-notation strings (e.g. "1.23E4" or "1.23E-4") to
    # decimal values for comparison
    def sci_no_to_decimal(value):
        return when(col(value).rlike(r"^[-+]?[0-9]*\.?[0-9]+[eE][-+]?[0-9]+$"),
                    col(value).cast(T.DecimalType(30, 10))).otherwise(col(value))

    df_return_list = []

    for df in [base_df, compare_df]:
        # only string-typed columns can hold scientific notation
        for column in df.columns:
            if df.schema[column].dataType == T.StringType():
                df = df.withColumn(column, sci_no_to_decimal(column))
        df_return_list.append(df)

    return df_return_list[0], df_return_list[1]
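
# Illustrative sketch (hypothetical data): because when()/otherwise() mixes a
# decimal branch with a string branch, Spark's type coercion typically resolves
# the result to a string column holding the expanded decimal text:
#
#     base = spark_session.createDataFrame([("1.23E4",), ("abc",)], ["val"])
#     comp = spark_session.createDataFrame([("12300",), ("abc",)], ["val"])
#     base_out, comp_out = convert_exponential_strings(base, comp)
#     # "1.23E4" is rewritten as its decimal expansion; "abc" is left unchanged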
190 changes: 190 additions & 0 deletions tests/test_spark/test_sql_spark.py
@@ -41,8 +41,15 @@
calculate_max_diff,
columns_equal,
temp_column_name,
sort_columns,
sort_rows,
format_numeric_fields,
handle_numeric_strings,
detailed_compare,
)
from pandas.testing import assert_series_equal
from pyspark.sql.types import StringType, StructField, StructType, IntegerType, DoubleType, BooleanType, Row


logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

@@ -1344,3 +1351,186 @@ def test_unicode_columns(spark_session):
assert compare.matches()
# Just render the report to make sure it renders.
compare.report()


def test_detailed_compare_with_string2columns(spark_session):
# create mock data
mock_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
mock_prod_columns = ["name", "age", "pet"]

mock_release_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
mock_release_columns = ["name", "age", "pet"]

# Create DataFrames
mock_prod_df = spark_session.createDataFrame(mock_prod_data, mock_prod_columns)
mock_release_df = spark_session.createDataFrame(mock_release_data, mock_release_columns)

# call detailed_compare
result_compared_data = detailed_compare(spark_session, mock_prod_df, mock_release_df, [], ["age"])

# assert result
assert result_compared_data.matches()
assert isinstance(result_compared_data, SparkSQLCompare)


def test_detailed_compare_with_column_to_join(spark_session):
# create mock data
mock_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
mock_prod_columns = ["name", "age", "pet"]

mock_release_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
mock_release_columns = ["name", "age", "pet"]

# Create DataFrames
mock_prod_df = spark_session.createDataFrame(mock_prod_data, mock_prod_columns)
mock_release_df = spark_session.createDataFrame(mock_release_data, mock_release_columns)

# call detailed_compare
result_compared_data = detailed_compare(spark_session, mock_prod_df, mock_release_df, ["name"], [])

# assert result
assert result_compared_data.matches()
assert isinstance(result_compared_data, SparkSQLCompare)


def test_handle_numeric_strings(spark_session):
# create mock_df
mock_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
mock_columns = ["name", "age", "pet"]
mock_df = spark_session.createDataFrame(mock_data, mock_columns)

# create mock field_list
mock_field_list = ["age"]

# call handle_numeric_strings
result_df = handle_numeric_strings(mock_df, mock_field_list)

# create expected dataframe
expected_data = [("bob", 22.0, "dog"), ("alice", 19.0, "cat"), ("john", 70.0, "bunny")]
expected_columns = ["name", "age", "pet"]
expected_df = spark_session.createDataFrame(expected_data, expected_columns)

# assert calls
assert result_df.collect() == expected_df.collect()


def test_format_numeric_fields(spark_session):
# create mock dataframe
mock_data = [("bob", 22, "dog"), ("alice", 19, "cat"), ("john", 70, "bunny")]
mock_columns = ["name", "age", "pet"]
mock_df = spark_session.createDataFrame(mock_data, mock_columns)

# call format_numeric_fields
formatted_df = format_numeric_fields(mock_df)

# create expected dataframe
expected_data = [("bob", "22.00000", "dog"), ("alice", "19.00000", "cat"), ("john", "70.00000", "bunny")]
expected_columns = ["name", "age", "pet"]
expected_df = spark_session.createDataFrame(expected_data, expected_columns)

# assert calls
assert formatted_df.collect() == expected_df.collect()


def test_sort_rows_failure(spark_session):
# create mock dataframes
input_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
columns_prod = ["name", "age", "pet"]

input_release_data = [("19", "cat"), ("70", "bunny"), ("22", "dog")]
columns_release = ["age", "pet"]

# Create DataFrames
input_prod_df = spark_session.createDataFrame(input_prod_data, columns_prod)
input_release_df = spark_session.createDataFrame(input_release_data, columns_release)

    # call sort_rows
with pytest.raises(Exception, match="name is present in prod_df but does not exist in release_df"):
sort_rows(input_prod_df, input_release_df)


def test_sort_rows_success(capsys, spark_session):
# create mock data
input_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
columns_prod = ["name", "age", "pet"]

input_release_data = [("19", "cat", "alice", "red"), ("70", "bunny", "john", "black"), ("22", "dog", "bob", "white")]
columns_release = ["age", "pet", "name", "color"]

# create dataFrames
input_prod_df = spark_session.createDataFrame(input_prod_data, columns_prod)
input_release_df = spark_session.createDataFrame(input_release_data, columns_release)

# call sort_rows
sorted_prod_df, sorted_release_df = sort_rows(input_prod_df, input_release_df)

# create expected prod_dataframe
expected_prod_data = [("alice", "19", "cat", 1), ("bob", "22", "dog", 2), ("john", "70", "bunny", 3)]
expected_prod_schema = StructType([
StructField("name", StringType(), True),
StructField("age", StringType(), True),
StructField("pet", StringType(), True),
StructField("row", IntegerType(), True)
])
expected_prod_df = spark_session.createDataFrame(expected_prod_data, expected_prod_schema)

# create expected release_dataframe
expected_release_data = [("19", "cat", "alice", "red", 1), ("22", "dog", "bob", "white", 2), ("70", "bunny", "john", "black", 3)]
expected_release_schema = StructType([
StructField("age", StringType(), True),
StructField("pet", StringType(), True),
StructField("name", StringType(), True),
StructField("color", StringType(), True),
StructField("row", IntegerType(), True)
])
expected_release_df = spark_session.createDataFrame(expected_release_data, expected_release_schema)

# assertions
assert sorted_prod_df.collect() == expected_prod_df.collect()
assert sorted_release_df.collect() == expected_release_df.collect()

captured = capsys.readouterr()
assert captured.out == "WARNING: There are columns present in Compare df that do not exist in Base df. The Base df columns will be used for row-wise sorting and may produce unanticipated report output if the extra fields are not null.\n"


def test_sort_columns_failure(spark_session):
# create mock dataframes
input_prod_data = [("row1", "col2", "col3"), ("row2", "col2", "col3"), ("row3", "col2", "col3")]
columns_1 = ["col1", "col2", "col3"]

input_release_data = [("row1", "col2"), ("row2", "col2"), ("row3", "col2")]
columns_2 = ["col1", "col2"]

# Create DataFrames
input_prod_df = spark_session.createDataFrame(input_prod_data, columns_1)
input_release_df = spark_session.createDataFrame(input_release_data, columns_2)

# call sort_columns
with pytest.raises(Exception, match="col3 is present in prod_df but does not exist in release_df"):
sort_columns(input_prod_df, input_release_df)


def test_sort_columns_success(spark_session):
# create mock dataframes
input_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
columns_prod = ["name", "age", "pet"]

input_release_data = [("19", "cat", "alice"), ("70", "bunny", "john"), ("22", "dog", "bob")]
columns_release = ["age", "pet", "name"]

# create input dataFrames
input_prod_df = spark_session.createDataFrame(input_prod_data, columns_prod)
input_release_df = spark_session.createDataFrame(input_release_data, columns_release)

# create expected dataFrames
expected_prod_data = [("alice", "19", "cat"), ("bob", "22", "dog"), ("john", "70", "bunny")]
expected_release_data = [("19", "cat", "alice"), ("22", "dog", "bob"), ("70", "bunny", "john")]
expected_prod_df = spark_session.createDataFrame(expected_prod_data, columns_prod)
expected_release_df = spark_session.createDataFrame(expected_release_data, columns_release)

# call sort_columns
output_prod_df, output_release_df = sort_columns(input_prod_df, input_release_df)

# assert the dfs are equal
assert output_prod_df.collect() == expected_prod_df.collect()
assert output_release_df.collect() == expected_release_df.collect()
