From d3fa4814c85f2d2f9ed875593aa098d565cb6ce4 Mon Sep 17 00:00:00 2001
From: Ally Franken <ally.franken@capitalone.com>
Date: Wed, 15 Jan 2025 13:41:08 -0500
Subject: [PATCH] add helper functions and unit test cases

---
 datacompy/spark/sql.py             | 120 +++++++++++++++++-
 tests/test_spark/test_sql_spark.py | 190 +++++++++++++++++++++++++++++
 2 files changed, 308 insertions(+), 2 deletions(-)

diff --git a/datacompy/spark/sql.py b/datacompy/spark/sql.py
index abcfea9e..4f3a507e 100644
--- a/datacompy/spark/sql.py
+++ b/datacompy/spark/sql.py
@@ -36,6 +36,7 @@
 try:
     import pyspark.sql
     from pyspark.sql import Window
+    from pyspark.sql import types as T
     from pyspark.sql.functions import (
         abs,
         array,
@@ -49,6 +50,8 @@
         trim,
         upper,
         when,
+        format_number,
+        row_number,
     )
 except ImportError:
     LOG.warning(
@@ -1221,5 +1224,118 @@ def _is_comparable(type1: str, type2: str) -> bool:
         or ({type1, type2} == {"string", "date"})
     )
 
-def detailed_compare():
-    return
\ No newline at end of file
+
+def detailed_compare(spark_session, prod_dataframe, release_dataframe, column_to_join, string2double_cols=None):
+    """
+    Uses DataComPy library to run a more detailed analysis on results
+    :param prod_dataframe: Spark Dataframe
+        dataset to be compared against
+    :param release_dataframe: Spark Dataframe
+        dataset to compare
+    :param column_to_join: List of Strings, optional
+        the column by which the two datasets can be joined, an identifier that indicates which rows in both datasets
+        should be compared
+        if null, the rows are compared in the order they are given
+            dataset to compare
+    :param string2double_cols: List of Strings, optional
+        the columns that contain numeric values but are stored as string types
+    :return: SparkSQLCompare object, from which a report can be generated
+    """
+
+    # Convert fields that contain numeric values stored as strings to numeric types for comparison
+    if len(string2double_cols) != 0:
+        prod_dataframe = handle_numeric_strings(prod_dataframe, string2double_cols)
+        release_dataframe = handle_numeric_strings(release_dataframe, string2double_cols)
+
+    if len(column_to_join) == 0:
+        # will add a new column that numbers the rows so datasets can be compared by row number instead of by a
+        # common column
+        sorted_prod_df, sorted_release_df = sort_rows(prod_dataframe, release_dataframe)
+        column_to_join = ['row']
+    else:
+        sorted_prod_df = prod_dataframe
+        sorted_release_df = release_dataframe
+
+    print("Compared by column(s): ", column_to_join)
+    if string2double_cols:
+        print('String column(s) cast to doubles for numeric comparison: ', string2double_cols)
+    compared_data = SparkSQLCompare(spark_session, sorted_prod_df, sorted_release_df,
+                                 join_columns=column_to_join, abs_tol=0.0001)
+    return compared_data
+
+
+def handle_numeric_strings(df, field_list):
+    for this_col in field_list:
+        df = df.withColumn(this_col, col(this_col).cast(T.DoubleType()))
+    return df
+
+
+def format_numeric_fields(df):
+    fixed_cols = []
+    numeric_types = [
+        "tinyint",
+        "smallint",
+        "int",
+        "bigint",
+        "float",
+        "double",
+        "decimal"]
+
+    for c in df.dtypes:
+        # do not change non-numeric fields
+        if c[1] not in numeric_types:
+            fixed_cols.append(col(c[0]))
+        # round & truncate numeric fields
+        else:
+            new_val = format_number(col(c[0]), 5).alias(c[0])
+            fixed_cols.append(new_val)
+
+    formatted_df = df.select(*fixed_cols)
+    return formatted_df
+
+
+def sort_rows(prod_df, release_df):
+    prod_cols = prod_df.columns
+    release_cols = release_df.columns
+
+    # Ensure both DataFrames have the same columns
+    for x in prod_cols:
+        if x not in release_cols:
+            raise Exception(f"{x} is present in prod_df but does not exist in release_df")
+
+    if set(prod_cols) != set(release_cols):
+        print('WARNING: There are columns present in Compare df that do not exist in Base df. The Base df columns will be used for row-wise sorting and may produce unanticipated report output if the extra fields are not null.')
+
+    w = Window.orderBy(*prod_cols)
+    sorted_prod_df = prod_df.select('*', row_number().over(w).alias('row'))
+    sorted_release_df = release_df.select('*', row_number().over(w).alias('row'))
+    return sorted_prod_df, sorted_release_df
+
+
+def sort_columns(prod_df, release_df):
+    # Ensure both DataFrames have the same columns
+    common_columns = set(prod_df.columns)
+    for x in common_columns:
+        if x not in release_df.columns:
+            raise Exception(f"{x} is present in prod_df but does not exist in release_df")
+    # Sort both DataFrames to ensure consistent order
+    prod_df = prod_df.orderBy(*common_columns)
+    release_df = release_df.orderBy(*common_columns)
+    return prod_df, release_df
+
+
+def convert_exponential_strings(base_df, compare_df):
+    # convert scientific number (1.23E4) to a decimal value
+    def sci_no_to_decimal(value):
+        return when(col(value).rlike(r"^[-+]?[0-9]*\.?[0-9]+[eE][0-9]+"),
+                      col(value).cast(T.DecimalType(30, 10))).otherwise(col(value))
+
+    df_return_list = []
+
+    for df in [base_df, compare_df]:
+        for column in df.columns:
+            if column in df.columns and df.schema[column].dataType == T.StringType():
+                df = df.withColumn(column, sci_no_to_decimal(column))
+        df_return_list.append(df)
+
+    return df_return_list[0], df_return_list[1]
\ No newline at end of file
diff --git a/tests/test_spark/test_sql_spark.py b/tests/test_spark/test_sql_spark.py
index e93fbafc..b5ac7468 100644
--- a/tests/test_spark/test_sql_spark.py
+++ b/tests/test_spark/test_sql_spark.py
@@ -41,8 +41,15 @@
     calculate_max_diff,
     columns_equal,
     temp_column_name,
+    sort_columns,
+    sort_rows,
+    format_numeric_fields,
+    handle_numeric_strings,
+    detailed_compare,
 )
 from pandas.testing import assert_series_equal
+from pyspark.sql.types import StringType, StructField, StructType, IntegerType, DoubleType, BooleanType, Row
+
 
 logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 
@@ -1344,3 +1351,186 @@ def test_unicode_columns(spark_session):
     assert compare.matches()
     # Just render the report to make sure it renders.
     compare.report()
+
+
+def test_detailed_compare_with_string2columns(spark_session):
+    # create mock data
+    mock_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    mock_prod_columns = ["name", "age", "pet"]
+
+    mock_release_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    mock_release_columns = ["name", "age", "pet"]
+
+    # Create DataFrames
+    mock_prod_df = spark_session.createDataFrame(mock_prod_data, mock_prod_columns)
+    mock_release_df = spark_session.createDataFrame(mock_release_data, mock_release_columns)
+
+    # call detailed_compare
+    result_compared_data = detailed_compare(spark_session, mock_prod_df, mock_release_df, [], ["age"])
+
+    # assert result
+    assert result_compared_data.matches()
+    assert isinstance(result_compared_data, SparkSQLCompare)
+
+
+def test_detailed_compare_with_column_to_join(spark_session):
+    # create mock data
+    mock_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    mock_prod_columns = ["name", "age", "pet"]
+
+    mock_release_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    mock_release_columns = ["name", "age", "pet"]
+
+    # Create DataFrames
+    mock_prod_df = spark_session.createDataFrame(mock_prod_data, mock_prod_columns)
+    mock_release_df = spark_session.createDataFrame(mock_release_data, mock_release_columns)
+
+    # call detailed_compare
+    result_compared_data = detailed_compare(spark_session, mock_prod_df, mock_release_df, ["name"], [])
+
+    # assert result
+    assert result_compared_data.matches()
+    assert isinstance(result_compared_data, SparkSQLCompare)
+
+
+def test_handle_numeric_strings(spark_session):
+    # create mock_df
+    mock_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    mock_columns = ["name", "age", "pet"]
+    mock_df = spark_session.createDataFrame(mock_data, mock_columns)
+
+    # create mock field_list
+    mock_field_list = ["age"]
+
+    # call handle_numeric_strings
+    result_df = handle_numeric_strings(mock_df, mock_field_list)
+
+    # create expected dataframe
+    expected_data = [("bob", 22.0, "dog"), ("alice", 19.0, "cat"), ("john", 70.0, "bunny")]
+    expected_columns = ["name", "age", "pet"]
+    expected_df = spark_session.createDataFrame(expected_data, expected_columns)
+
+    # assert calls
+    assert result_df.collect() == expected_df.collect()
+
+
+def test_format_numeric_fields(spark_session):
+    # create mock dataframe
+    mock_data = [("bob", 22, "dog"), ("alice", 19, "cat"), ("john", 70, "bunny")]
+    mock_columns = ["name", "age", "pet"]
+    mock_df = spark_session.createDataFrame(mock_data, mock_columns)
+
+    # call format_numeric_fields
+    formatted_df = format_numeric_fields(mock_df)
+
+    # create expected dataframe
+    expected_data = [("bob", "22.00000", "dog"), ("alice", "19.00000", "cat"), ("john", "70.00000", "bunny")]
+    expected_columns = ["name", "age", "pet"]
+    expected_df = spark_session.createDataFrame(expected_data, expected_columns)
+
+    # assert calls
+    assert formatted_df.collect() == expected_df.collect()
+
+
+def test_sort_rows_failure(spark_session):
+    # create mock dataframes
+    input_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    columns_prod = ["name", "age", "pet"]
+
+    input_release_data = [("19", "cat"), ("70", "bunny"), ("22", "dog")]
+    columns_release = ["age", "pet"]
+
+    # Create DataFrames
+    input_prod_df = spark_session.createDataFrame(input_prod_data, columns_prod)
+    input_release_df = spark_session.createDataFrame(input_release_data, columns_release)
+
+    # call call_rows
+    with pytest.raises(Exception, match="name is present in prod_df but does not exist in release_df"):
+        sort_rows(input_prod_df, input_release_df)
+
+
+def test_sort_rows_success(capsys, spark_session):
+    # create mock data
+    input_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    columns_prod = ["name", "age", "pet"]
+
+    input_release_data = [("19", "cat", "alice", "red"), ("70", "bunny", "john", "black"), ("22", "dog", "bob", "white")]
+    columns_release = ["age", "pet", "name", "color"]
+
+    # create dataFrames
+    input_prod_df = spark_session.createDataFrame(input_prod_data, columns_prod)
+    input_release_df = spark_session.createDataFrame(input_release_data, columns_release)
+
+    # call sort_rows
+    sorted_prod_df, sorted_release_df = sort_rows(input_prod_df, input_release_df)
+
+    # create expected prod_dataframe
+    expected_prod_data = [("alice", "19", "cat", 1), ("bob", "22", "dog", 2), ("john", "70", "bunny", 3)]
+    expected_prod_schema = StructType([
+        StructField("name", StringType(), True),
+        StructField("age", StringType(), True),
+        StructField("pet", StringType(), True),
+        StructField("row", IntegerType(), True)
+    ])
+    expected_prod_df = spark_session.createDataFrame(expected_prod_data, expected_prod_schema)
+
+    # create expected release_dataframe
+    expected_release_data = [("19", "cat", "alice", "red", 1), ("22", "dog", "bob", "white", 2), ("70", "bunny", "john", "black", 3)]
+    expected_release_schema = StructType([
+        StructField("age", StringType(), True),
+        StructField("pet", StringType(), True),
+        StructField("name", StringType(), True),
+        StructField("color", StringType(), True),
+        StructField("row", IntegerType(), True)
+    ])
+    expected_release_df = spark_session.createDataFrame(expected_release_data, expected_release_schema)
+
+    # assertions
+    assert sorted_prod_df.collect() == expected_prod_df.collect()
+    assert sorted_release_df.collect() == expected_release_df.collect()
+
+    captured = capsys.readouterr()
+    assert captured.out == "WARNING: There are columns present in Compare df that do not exist in Base df. The Base df columns will be used for row-wise sorting and may produce unanticipated report output if the extra fields are not null.\n"
+
+
+def test_sort_columns_failure(spark_session):
+    # create mock dataframes
+    input_prod_data = [("row1", "col2", "col3"), ("row2", "col2", "col3"), ("row3", "col2", "col3")]
+    columns_1 = ["col1", "col2", "col3"]
+
+    input_release_data = [("row1", "col2"), ("row2", "col2"), ("row3", "col2")]
+    columns_2 = ["col1", "col2"]
+
+    # Create DataFrames
+    input_prod_df = spark_session.createDataFrame(input_prod_data, columns_1)
+    input_release_df = spark_session.createDataFrame(input_release_data, columns_2)
+
+    # call sort_columns
+    with pytest.raises(Exception, match="col3 is present in prod_df but does not exist in release_df"):
+        sort_columns(input_prod_df, input_release_df)
+
+
+def test_sort_columns_success(spark_session):
+    # create mock dataframes
+    input_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    columns_prod = ["name", "age", "pet"]
+
+    input_release_data = [("19", "cat", "alice"), ("70", "bunny", "john"), ("22", "dog", "bob")]
+    columns_release = ["age", "pet", "name"]
+
+    # create input dataFrames
+    input_prod_df = spark_session.createDataFrame(input_prod_data, columns_prod)
+    input_release_df = spark_session.createDataFrame(input_release_data, columns_release)
+
+    # create expected dataFrames
+    expected_prod_data = [("alice", "19", "cat"), ("bob", "22", "dog"), ("john", "70", "bunny")]
+    expected_release_data = [("19", "cat", "alice"), ("22", "dog", "bob"), ("70", "bunny", "john")]
+    expected_prod_df = spark_session.createDataFrame(expected_prod_data, columns_prod)
+    expected_release_df = spark_session.createDataFrame(expected_release_data, columns_release)
+
+    # call sort_columns
+    output_prod_df, output_release_df = sort_columns(input_prod_df, input_release_df)
+
+    # assert the dfs are equal
+    assert output_prod_df.collect() == expected_prod_df.collect()
+    assert output_release_df.collect() == expected_release_df.collect()
\ No newline at end of file