Skip to content

Commit

Permalink
remove unnecessary helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Ally Franken committed Jan 15, 2025
1 parent d3fa481 commit 80c89c8
Showing 1 changed file with 20 additions and 37 deletions.
57 changes: 20 additions & 37 deletions datacompy/spark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,30 +1270,6 @@ def handle_numeric_strings(df, field_list):
return df


def format_numeric_fields(df):
    """Format every numeric column of ``df`` to five decimal places.

    Numeric columns are replaced by ``format_number(col, 5)`` under their
    original name; all other columns are selected unchanged. Column order is
    preserved.

    Parameters
    ----------
    df : DataFrame
        Input dataframe; only ``df.dtypes`` and ``df.select`` are used.

    Returns
    -------
    DataFrame
        Result of selecting the formatted / passthrough columns from ``df``.
    """
    # Type names that ``df.dtypes`` reports verbatim. These are matched
    # exactly so that unrelated names sharing a prefix (e.g. "interval ...")
    # are not treated as numeric.
    exact_numeric_types = {
        "tinyint",
        "smallint",
        "int",
        "bigint",
        "float",
        "double",
    }

    def is_numeric(dtype):
        # Decimal columns are reported with precision and scale, e.g.
        # "decimal(10,2)", so an exact comparison against "decimal" never
        # matches; match on the prefix instead.
        return dtype in exact_numeric_types or dtype.startswith("decimal")

    fixed_cols = [
        # round & truncate numeric fields, leave non-numeric fields untouched
        format_number(col(name), 5).alias(name) if is_numeric(dtype) else col(name)
        for name, dtype in df.dtypes
    ]

    return df.select(*fixed_cols)


def sort_rows(prod_df, release_df):
prod_cols = prod_df.columns
release_cols = release_df.columns
Expand Down Expand Up @@ -1324,18 +1300,25 @@ def sort_columns(prod_df, release_df):
return prod_df, release_df


def convert_exponential_strings(base_df, compare_df):
    """Convert scientific-notation strings in string columns to decimals.

    For every ``StringType`` column of each dataframe, values that look like a
    number in scientific notation (``1.23E4``, ``-4e-3``) are cast to
    ``DecimalType(30, 10)``; every other value is left as it was.

    Parameters
    ----------
    base_df : DataFrame
        First dataframe to convert.
    compare_df : DataFrame
        Second dataframe to convert.

    Returns
    -------
    tuple
        ``(base_df, compare_df)`` after conversion, in that order.
    """
    # ``rlike`` performs a substring search, so the pattern is anchored at
    # both ends: otherwise a value such as "1e5abc" would match and then be
    # nulled by the failed cast. The exponent may carry a sign (1.23E-4).
    sci_pattern = r"^[-+]?[0-9]*\.?[0-9]+[eE][-+]?[0-9]+$"

    # convert scientific number (1.23E4) to a decimal value
    def sci_no_to_decimal(value):
        return when(
            col(value).rlike(sci_pattern),
            col(value).cast(T.DecimalType(30, 10)),
        ).otherwise(col(value))

    def convert_string_columns(df):
        # Only string columns can hold textual scientific notation.
        for column in df.columns:
            if df.schema[column].dataType == T.StringType():
                df = df.withColumn(column, sci_no_to_decimal(column))
        return df

    return convert_string_columns(base_df), convert_string_columns(compare_df)


def format_numeric_fields(df):
    """Format every numeric column of ``df`` to five decimal places.

    Numeric columns are replaced by ``format_number(col, 5)`` under their
    original name; all other columns are selected unchanged. Column order is
    preserved.

    Parameters
    ----------
    df : DataFrame
        Input dataframe; only ``df.dtypes`` and ``df.select`` are used.

    Returns
    -------
    DataFrame
        Result of selecting the formatted / passthrough columns from ``df``.
    """
    # Type names that ``df.dtypes`` reports verbatim. These are matched
    # exactly so that unrelated names sharing a prefix (e.g. "interval ...")
    # are not treated as numeric.
    exact_numeric_types = {
        "tinyint",
        "smallint",
        "int",
        "bigint",
        "float",
        "double",
    }

    def is_numeric(dtype):
        # Decimal columns are reported with precision and scale, e.g.
        # "decimal(10,2)", so an exact comparison against "decimal" never
        # matches; match on the prefix instead.
        return dtype in exact_numeric_types or dtype.startswith("decimal")

    fixed_cols = [
        # round & truncate numeric fields, leave non-numeric fields untouched
        format_number(col(name), 5).alias(name) if is_numeric(dtype) else col(name)
        for name, dtype in df.dtypes
    ]

    return df.select(*fixed_cols)

0 comments on commit 80c89c8

Please sign in to comment.