diff --git a/defog_utils/utils_sql.py b/defog_utils/utils_sql.py index 978b862..fa53de7 100644 --- a/defog_utils/utils_sql.py +++ b/defog_utils/utils_sql.py @@ -563,10 +563,12 @@ def fix_comma(cols: List[str]) -> List[str]: if "--" in col: # check if comma is before comment if "," not in col[: col.index("--")]: - col = col.replace("--", ", --") + # use re.sub to replace (any whitespace)-- with , -- + col = re.sub(r"\s*--", ", --", col) # check if string ends with comma elif "," not in col: - col += "," + # replace all trailing spaces with , + col = re.sub(r"\s+$", ",", col) fixed_cols.append(col) # for the last col, we want to remove the comma last_col = fixed_cols[-1] diff --git a/tests/test_utils_sql.py b/tests/test_utils_sql.py index ec98e09..960e73e 100644 --- a/tests/test_utils_sql.py +++ b/tests/test_utils_sql.py @@ -822,10 +822,10 @@ def test_schema_1(self): class TestShuffleTableMetadata(unittest.TestCase): - def test_shuffle_table_metadata_seed(self): + def test_shuffle_table_metadata_seed_1(self): input_md_str = """CREATE SCHEMA IF NOT EXISTS TEST_DB; CREATE TABLE TEST_DB.PUBLIC.CUSTOMERS ( - CUSTOMER_EMAIL VARCHAR, --Email address of the customer + CUSTOMER_EMAIL VARCHAR, CUSTOMER_PHONE VARCHAR, --Phone number of the customer CUSTOMER_ID NUMERIC, --Unique identifier for each customer CUSTOMER_NAME VARCHAR --Name of the customer @@ -855,8 +855,8 @@ def test_shuffle_table_metadata_seed(self): ); CREATE TABLE TEST_DB.PUBLIC.CUSTOMERS ( CUSTOMER_PHONE VARCHAR, --Phone number of the customer - CUSTOMER_NAME VARCHAR , --Name of the customer - CUSTOMER_EMAIL VARCHAR, --Email address of the customer + CUSTOMER_NAME VARCHAR, --Name of the customer + CUSTOMER_EMAIL VARCHAR, CUSTOMER_ID NUMERIC --Unique identifier for each customer ); CREATE TABLE patient ( @@ -870,6 +870,19 @@ def test_shuffle_table_metadata_seed(self): md_shuffled = shuffle_table_metadata(input_md_str, 42) self.maxDiff = None self.assertEqual(md_shuffled, expected_md_shuffled) + + def test_shuffle_table_metadata_seed_2(self): + input_md_str = """CREATE TABLE branch_info ( + branch_open_date date, --Date branch opened + manager_name varchar(100) --Name of the branch manager +);""" + expected_md_shuffled = """CREATE TABLE branch_info ( + manager_name varchar(100), --Name of the branch manager + branch_open_date date --Date branch opened +);""" + md_shuffled = shuffle_table_metadata(input_md_str, 0) + print(md_shuffled) + self.assertEqual(md_shuffled, expected_md_shuffled) class TestFunctions(unittest.TestCase):