Skip to content

Commit 9d7660c

Browse files
committed
Fix a few trailing-comma edge cases in fix_comma when shuffling table metadata
1 parent abafa11 commit 9d7660c

File tree

2 files changed

+45
-11
lines changed

2 files changed

+45
-11
lines changed

defog_utils/utils_sql.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -565,22 +565,28 @@ def fix_comma(cols: List[str]) -> List[str]:
565565
if "," not in col[: col.index("--")]:
566566
# use re.sub to replace (any whitespace)-- with , --
567567
col = re.sub(r"\s*--", ", --", col)
568-
# check if string ends with comma
569-
elif "," not in col:
568+
# check if string ends with comma (optionally with additional spaces)
569+
elif not re.search(r",\s*$", col):
570570
# replace all trailing spaces with ,
571571
col = re.sub(r"\s+$", ",", col)
572572
fixed_cols.append(col)
573573
# for the last col, we want to remove the comma
574574
last_col = fixed_cols[-1]
575575
if "--" in last_col:
576-
# check if comma is before comment and remove if present
577-
last_col_split = last_col.split("--", 1)
578-
if "," in last_col_split[0]:
579-
last_col = (
580-
"".join(last_col_split[0].split(",", 1)) + "--" + last_col_split[1]
581-
)
582-
elif "," in last_col:
583-
last_col = "".join(last_col.split(",", 1))
576+
# if a comma (optionally followed by spaces) sits right before the -- comment, e.g. after a word or closing parenthesis, remove it
577+
578+
pre_comment, after_comment = last_col.split("--", 1)
579+
# check if pre_comment ends with a comma with optional spaces
580+
if re.search(r",\s*$", pre_comment):
581+
pre_comment = re.sub(r",\s*$", "", pre_comment)
582+
# remove any trailing spaces in pre_comment
583+
pre_comment = pre_comment.rstrip()
584+
last_col = pre_comment + " --" + after_comment
585+
# if last_col ends with a comma with optional spaces, remove it
586+
elif re.search(r",\s*$", last_col):
587+
last_col = re.sub(r",\s*$", "", last_col)
588+
# remove any trailing spaces in last_col
589+
last_col = last_col.rstrip()
584590
fixed_cols[-1] = last_col
585591
return fixed_cols
586592

tests/test_utils_sql.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
from defog_utils.defog_utils.utils_sql import (
33
add_space_padding,
4+
fix_comma,
45
get_schema_features,
56
get_sql_features,
67
is_date_or_time_str,
@@ -821,6 +822,23 @@ def test_schema_1(self):
821822
self.assertEqual(positive_features, expected_positive)
822823

823824

825+
class TestFixComma(unittest.TestCase):
826+
def test_fix_comma_1(self):
827+
cols = [
828+
" CUSTOMER_EMAIL VARCHAR,",
829+
" CUSTOMER_PHONE VARCHAR(200) --Phone number of the customer", # add comma
830+
" value numeric(10,2),", # remove trailing comma
831+
]
832+
expected = [
833+
" CUSTOMER_EMAIL VARCHAR,",
834+
" CUSTOMER_PHONE VARCHAR(200), --Phone number of the customer",
835+
" value numeric(10,2)",
836+
]
837+
result = fix_comma(cols)
838+
print(result)
839+
self.assertEqual(result, expected)
840+
841+
824842
class TestShuffleTableMetadata(unittest.TestCase):
825843
def test_shuffle_table_metadata_seed_1(self):
826844
input_md_str = """CREATE SCHEMA IF NOT EXISTS TEST_DB;
@@ -874,10 +892,20 @@ def test_shuffle_table_metadata_seed_1(self):
874892
def test_shuffle_table_metadata_seed_2(self):
875893
input_md_str = """CREATE TABLE branch_info (
876894
branch_open_date date, --Date branch opened
895+
value numeric(10,2),
877896
manager_name varchar(100) --Name of the branch manager
897+
);
898+
CREATE TABLE employee (
899+
employee_id integer,
900+
ytd_return numeric(5,2)
878901
);"""
879-
expected_md_shuffled = """CREATE TABLE branch_info (
902+
expected_md_shuffled = """CREATE TABLE employee (
903+
employee_id integer,
904+
ytd_return numeric(5,2)
905+
);
906+
CREATE TABLE branch_info (
880907
manager_name varchar(100), --Name of the branch manager
908+
value numeric(10,2),
881909
branch_open_date date --Date branch opened
882910
);"""
883911
md_shuffled = shuffle_table_metadata(input_md_str, 0)

0 commit comments

Comments
 (0)