diff --git a/datacompy/spark/sql.py b/datacompy/spark/sql.py
index abcfea9e..4f3a507e 100644
--- a/datacompy/spark/sql.py
+++ b/datacompy/spark/sql.py
@@ -36,6 +36,7 @@ try:
     import pyspark.sql
     from pyspark.sql import Window
+    from pyspark.sql import types as T
     from pyspark.sql.functions import (
         abs,
         array,
@@ -49,6 +50,8 @@
         trim,
         upper,
         when,
+        format_number,
+        row_number,
     )
 except ImportError:
     LOG.warning(
@@ -1221,5 +1224,118 @@ def _is_comparable(type1: str, type2: str) -> bool:
         or ({type1, type2} == {"string", "date"})
     )
 
-def detailed_compare():
-    return
\ No newline at end of file
+
+def detailed_compare(
+    spark_session,
+    prod_dataframe,
+    release_dataframe,
+    column_to_join=None,
+    string2double_cols=None,
+):
+    """Run a detailed DataComPy comparison of two Spark DataFrames.
+
+    :param spark_session: SparkSession
+        the active Spark session used to build the comparison
+    :param prod_dataframe: Spark DataFrame
+        dataset to be compared against
+    :param release_dataframe: Spark DataFrame
+        dataset to compare
+    :param column_to_join: list of strings, optional
+        the column(s) on which the two datasets can be joined, an identifier that
+        indicates which rows in both datasets should be compared;
+        if empty or None, the rows are compared in the order they are given
+    :param string2double_cols: list of strings, optional
+        the columns that contain numeric values but are stored as string types
+    :return: SparkSQLCompare object, from which a report can be generated
+    """
+    # Convert fields that contain numeric values stored as strings to numeric
+    # types for comparison
+    if string2double_cols:
+        prod_dataframe = handle_numeric_strings(prod_dataframe, string2double_cols)
+        release_dataframe = handle_numeric_strings(release_dataframe, string2double_cols)
+
+    if not column_to_join:
+        # Add a "row" column numbering the rows so the datasets can be compared
+        # by row number instead of by a common column
+        sorted_prod_df, sorted_release_df = sort_rows(prod_dataframe, release_dataframe)
+        column_to_join = ["row"]
+    else:
+        sorted_prod_df = prod_dataframe
+        sorted_release_df = release_dataframe
+
+    print("Compared by column(s): ", column_to_join)
+    if string2double_cols:
+        print("String column(s) cast to doubles for numeric comparison: ", string2double_cols)
+    compared_data = SparkSQLCompare(
+        spark_session,
+        sorted_prod_df,
+        sorted_release_df,
+        join_columns=column_to_join,
+        abs_tol=0.0001,
+    )
+    return compared_data
+
+
+def handle_numeric_strings(df, field_list):
+    # Cast each listed column to a double so it compares numerically
+    for this_col in field_list:
+        df = df.withColumn(this_col, col(this_col).cast(T.DoubleType()))
+    return df
+
+
+def format_numeric_fields(df):
+    # Round every numeric column to five decimal places; note that
+    # format_number returns strings and inserts grouping commas (e.g. "1,234.56789")
+    fixed_cols = []
+    numeric_types = [
+        "tinyint",
+        "smallint",
+        "int",
+        "bigint",
+        "float",
+        "double",
+        "decimal",
+    ]
+
+    for name, dtype in df.dtypes:
+        if dtype not in numeric_types:
+            # do not change non-numeric fields
+            fixed_cols.append(col(name))
+        else:
+            # round & truncate numeric fields
+            fixed_cols.append(format_number(col(name), 5).alias(name))
+
+    formatted_df = df.select(*fixed_cols)
+    return formatted_df
+
+
+def sort_rows(prod_df, release_df):
+    prod_cols = prod_df.columns
+    release_cols = release_df.columns
+
+    # Ensure every prod column also exists in the release DataFrame
+    for x in prod_cols:
+        if x not in release_cols:
+            raise Exception(f"{x} is present in prod_df but does not exist in release_df")
+
+    if set(prod_cols) != set(release_cols):
+        print(
+            "WARNING: There are columns present in Compare df that do not exist in Base df. "
+            "The Base df columns will be used for row-wise sorting and may produce "
+            "unanticipated report output if the extra fields are not null."
+        )
+
+    # Number the rows of both DataFrames using the same ordering over the prod columns
+    w = Window.orderBy(*prod_cols)
+    sorted_prod_df = prod_df.select("*", row_number().over(w).alias("row"))
+    sorted_release_df = release_df.select("*", row_number().over(w).alias("row"))
+    return sorted_prod_df, sorted_release_df
+
+
+def sort_columns(prod_df, release_df):
+    # Ensure both DataFrames have the same columns; keep the prod column order
+    # (a list, not a set) so the sort below is deterministic
+    common_columns = prod_df.columns
+    for x in common_columns:
+        if x not in release_df.columns:
+            raise Exception(f"{x} is present in prod_df but does not exist in release_df")
+    # Sort the rows of both DataFrames by the common columns to ensure a
+    # consistent order
+    prod_df = prod_df.orderBy(*common_columns)
+    release_df = release_df.orderBy(*common_columns)
+    return prod_df, release_df
+
+
+def convert_exponential_strings(base_df, compare_df):
+    # Rewrite scientific-notation strings (e.g. "1.23E4") as plain decimal values
+    def sci_no_to_decimal(value):
+        return when(
+            col(value).rlike(r"^[-+]?[0-9]*\.?[0-9]+[eE][-+]?[0-9]+$"),
+            col(value).cast(T.DecimalType(30, 10)),
+        ).otherwise(col(value))
+
+    df_return_list = []
+
+    for df in [base_df, compare_df]:
+        for column in df.columns:
+            if df.schema[column].dataType == T.StringType():
+                df = df.withColumn(column, sci_no_to_decimal(column))
+        df_return_list.append(df)
+
+    return df_return_list[0], df_return_list[1]
\ No newline at end of file
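For reference, a minimal usage sketch of the new entry point. The DataFrames here are illustrative and spark is assumed to be an active SparkSession; matches() and report() are used the same way in the tests below:

    from datacompy.spark.sql import detailed_compare

    prod_df = spark.createDataFrame([("bob", "22"), ("alice", "19")], ["name", "age"])
    release_df = spark.createDataFrame([("bob", "22"), ("alice", "19")], ["name", "age"])

    # join on "name"; cast the string-typed "age" column to double before comparing
    compared = detailed_compare(spark, prod_df, release_df, ["name"], ["age"])
    assert compared.matches()
    compared.report()  # render the full comparison report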
diff --git a/tests/test_spark/test_sql_spark.py b/tests/test_spark/test_sql_spark.py
index e93fbafc..b5ac7468 100644
--- a/tests/test_spark/test_sql_spark.py
+++ b/tests/test_spark/test_sql_spark.py
@@ -41,8 +41,15 @@
     calculate_max_diff,
     columns_equal,
     temp_column_name,
+    sort_columns,
+    sort_rows,
+    format_numeric_fields,
+    handle_numeric_strings,
+    detailed_compare,
 )
 from pandas.testing import assert_series_equal
+from pyspark.sql.types import IntegerType, StringType, StructField, StructType
+
 
 logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
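One behavior worth noting before the tests: Spark's format_number, which format_numeric_fields relies on, returns strings and inserts grouping commas for values of 1,000 and above. A small illustration, assuming an active SparkSession named spark:

    from pyspark.sql.functions import col, format_number

    df = spark.createDataFrame([(1234.56789,)], ["amount"])
    df.select(format_number(col("amount"), 5).alias("amount")).first()[0]
    # -> '1,234.56789'  (string-typed; note the grouping comma)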
@@ -1344,3 +1351,186 @@
     assert compare.matches()
     # Just render the report to make sure it renders.
     compare.report()
+
+
+def test_detailed_compare_with_string2columns(spark_session):
+    # create mock data
+    mock_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    mock_prod_columns = ["name", "age", "pet"]
+
+    mock_release_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    mock_release_columns = ["name", "age", "pet"]
+
+    # create DataFrames
+    mock_prod_df = spark_session.createDataFrame(mock_prod_data, mock_prod_columns)
+    mock_release_df = spark_session.createDataFrame(mock_release_data, mock_release_columns)
+
+    # call detailed_compare, casting "age" to double and joining by row number
+    result_compared_data = detailed_compare(spark_session, mock_prod_df, mock_release_df, [], ["age"])
+
+    # assert result
+    assert result_compared_data.matches()
+    assert isinstance(result_compared_data, SparkSQLCompare)
+
+
+def test_detailed_compare_with_column_to_join(spark_session):
+    # create mock data
+    mock_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    mock_prod_columns = ["name", "age", "pet"]
+
+    mock_release_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    mock_release_columns = ["name", "age", "pet"]
+
+    # create DataFrames
+    mock_prod_df = spark_session.createDataFrame(mock_prod_data, mock_prod_columns)
+    mock_release_df = spark_session.createDataFrame(mock_release_data, mock_release_columns)
+
+    # call detailed_compare, joining on "name"
+    result_compared_data = detailed_compare(spark_session, mock_prod_df, mock_release_df, ["name"], [])
+
+    # assert result
+    assert result_compared_data.matches()
+    assert isinstance(result_compared_data, SparkSQLCompare)
+
+
+def test_handle_numeric_strings(spark_session):
+    # create mock_df
+    mock_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    mock_columns = ["name", "age", "pet"]
+    mock_df = spark_session.createDataFrame(mock_data, mock_columns)
+
+    # create mock field_list
+    mock_field_list = ["age"]
+
+    # call handle_numeric_strings
+    result_df = handle_numeric_strings(mock_df, mock_field_list)
+
+    # create expected DataFrame with "age" cast to double
+    expected_data = [("bob", 22.0, "dog"), ("alice", 19.0, "cat"), ("john", 70.0, "bunny")]
+    expected_columns = ["name", "age", "pet"]
+    expected_df = spark_session.createDataFrame(expected_data, expected_columns)
+
+    # assert calls
+    assert result_df.collect() == expected_df.collect()
+
+
+def test_format_numeric_fields(spark_session):
+    # create mock DataFrame
+    mock_data = [("bob", 22, "dog"), ("alice", 19, "cat"), ("john", 70, "bunny")]
+    mock_columns = ["name", "age", "pet"]
+    mock_df = spark_session.createDataFrame(mock_data, mock_columns)
+
+    # call format_numeric_fields
+    formatted_df = format_numeric_fields(mock_df)
+
+    # create expected DataFrame: numeric fields become five-decimal strings
+    expected_data = [("bob", "22.00000", "dog"), ("alice", "19.00000", "cat"), ("john", "70.00000", "bunny")]
+    expected_columns = ["name", "age", "pet"]
+    expected_df = spark_session.createDataFrame(expected_data, expected_columns)
+
+    # assert calls
+    assert formatted_df.collect() == expected_df.collect()
+
+
+def test_sort_rows_failure(spark_session):
+    # create mock DataFrames; the release side is missing the "name" column
+    input_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    columns_prod = ["name", "age", "pet"]
+
+    input_release_data = [("19", "cat"), ("70", "bunny"), ("22", "dog")]
+    columns_release = ["age", "pet"]
+
+    # create DataFrames
+    input_prod_df = spark_session.createDataFrame(input_prod_data, columns_prod)
+    input_release_df = spark_session.createDataFrame(input_release_data, columns_release)
+
+    # call sort_rows and expect it to raise
+    with pytest.raises(Exception, match="name is present in prod_df but does not exist in release_df"):
+        sort_rows(input_prod_df, input_release_df)
+
+
+def test_sort_rows_success(capsys, spark_session):
+    # create mock data; the release side has an extra "color" column
+    input_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    columns_prod = ["name", "age", "pet"]
+
+    input_release_data = [("19", "cat", "alice", "red"), ("70", "bunny", "john", "black"), ("22", "dog", "bob", "white")]
+    columns_release = ["age", "pet", "name", "color"]
+
+    # create DataFrames
+    input_prod_df = spark_session.createDataFrame(input_prod_data, columns_prod)
+    input_release_df = spark_session.createDataFrame(input_release_data, columns_release)
+
+    # call sort_rows
+    sorted_prod_df, sorted_release_df = sort_rows(input_prod_df, input_release_df)
+
+    # create expected prod DataFrame
+    expected_prod_data = [("alice", "19", "cat", 1), ("bob", "22", "dog", 2), ("john", "70", "bunny", 3)]
+    expected_prod_schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", StringType(), True),
+            StructField("pet", StringType(), True),
+            StructField("row", IntegerType(), True),
+        ]
+    )
+    expected_prod_df = spark_session.createDataFrame(expected_prod_data, expected_prod_schema)
+
+    # create expected release DataFrame
+    expected_release_data = [("19", "cat", "alice", "red", 1), ("22", "dog", "bob", "white", 2), ("70", "bunny", "john", "black", 3)]
+    expected_release_schema = StructType(
+        [
+            StructField("age", StringType(), True),
+            StructField("pet", StringType(), True),
+            StructField("name", StringType(), True),
+            StructField("color", StringType(), True),
+            StructField("row", IntegerType(), True),
+        ]
+    )
+    expected_release_df = spark_session.createDataFrame(expected_release_data, expected_release_schema)
+
+    # assertions
+    assert sorted_prod_df.collect() == expected_prod_df.collect()
+    assert sorted_release_df.collect() == expected_release_df.collect()
+
+    captured = capsys.readouterr()
+    assert captured.out == (
+        "WARNING: There are columns present in Compare df that do not exist in Base df. "
+        "The Base df columns will be used for row-wise sorting and may produce "
+        "unanticipated report output if the extra fields are not null.\n"
+    )
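A performance caveat on sort_rows, relevant to the test above: Window.orderBy with no partitionBy forces Spark to move every row into a single partition, so Spark logs a warning along the lines of "WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation." Acceptable for validation-sized data, but worth keeping in mind at scale:

    # from sort_rows: a global (unpartitioned) window runs on one partition
    w = Window.orderBy(*prod_cols)
    sorted_prod_df = prod_df.select("*", row_number().over(w).alias("row"))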
+
+
+def test_sort_columns_failure(spark_session):
+    # create mock DataFrames; the release side is missing "col3"
+    input_prod_data = [("row1", "col2", "col3"), ("row2", "col2", "col3"), ("row3", "col2", "col3")]
+    columns_1 = ["col1", "col2", "col3"]
+
+    input_release_data = [("row1", "col2"), ("row2", "col2"), ("row3", "col2")]
+    columns_2 = ["col1", "col2"]
+
+    # create DataFrames
+    input_prod_df = spark_session.createDataFrame(input_prod_data, columns_1)
+    input_release_df = spark_session.createDataFrame(input_release_data, columns_2)
+
+    # call sort_columns and expect it to raise
+    with pytest.raises(Exception, match="col3 is present in prod_df but does not exist in release_df"):
+        sort_columns(input_prod_df, input_release_df)
+
+
+def test_sort_columns_success(spark_session):
+    # create mock DataFrames
+    input_prod_data = [("bob", "22", "dog"), ("alice", "19", "cat"), ("john", "70", "bunny")]
+    columns_prod = ["name", "age", "pet"]
+
+    input_release_data = [("19", "cat", "alice"), ("70", "bunny", "john"), ("22", "dog", "bob")]
+    columns_release = ["age", "pet", "name"]
+
+    # create input DataFrames
+    input_prod_df = spark_session.createDataFrame(input_prod_data, columns_prod)
+    input_release_df = spark_session.createDataFrame(input_release_data, columns_release)
+
+    # create expected DataFrames, row-sorted by the prod columns
+    expected_prod_data = [("alice", "19", "cat"), ("bob", "22", "dog"), ("john", "70", "bunny")]
+    expected_release_data = [("19", "cat", "alice"), ("22", "dog", "bob"), ("70", "bunny", "john")]
+    expected_prod_df = spark_session.createDataFrame(expected_prod_data, columns_prod)
+    expected_release_df = spark_session.createDataFrame(expected_release_data, columns_release)
+
+    # call sort_columns
+    output_prod_df, output_release_df = sort_columns(input_prod_df, input_release_df)
+
+    # assert the DataFrames are equal
+    assert output_prod_df.collect() == expected_prod_df.collect()
+    assert output_release_df.collect() == expected_release_df.collect()
\ No newline at end of file
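The patch also adds convert_exponential_strings, which is not yet exercised by a test or wired into detailed_compare. A minimal sketch of its intended behavior, assuming an active SparkSession named spark:

    from datacompy.spark.sql import convert_exponential_strings

    base = spark.createDataFrame([("1.23E4",), ("plain",)], ["val"])
    compare = spark.createDataFrame([("12300",), ("plain",)], ["val"])
    base, compare = convert_exponential_strings(base, compare)
    # scientific-notation strings such as "1.23E4" are rewritten in plain decimal
    # form via a DECIMAL(30, 10) cast; non-matching strings pass through unchanged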