Skip to content

Commit

Permalink
fixing a few edge cases while shuffling
Browse files Browse the repository at this point in the history
  • Loading branch information
wongjingping committed Jun 10, 2024
1 parent abafa11 commit 9d7660c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 11 deletions.
26 changes: 16 additions & 10 deletions defog_utils/utils_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,22 +565,28 @@ def fix_comma(cols: List[str]) -> List[str]:
if "," not in col[: col.index("--")]:
# use re.sub to replace (any whitespace)-- with , --
col = re.sub(r"\s*--", ", --", col)
# check if string ends with comma
elif "," not in col:
# check if string ends with comma (optionally with additional spaces)
elif not re.search(r",\s*$", col):
# replace all trailing spaces with ,
col = re.sub(r"\s+$", ",", col)
fixed_cols.append(col)
# for the last col, we want to remove the comma
last_col = fixed_cols[-1]
if "--" in last_col:
# check if comma is before comment and remove if present
last_col_split = last_col.split("--", 1)
if "," in last_col_split[0]:
last_col = (
"".join(last_col_split[0].split(",", 1)) + "--" + last_col_split[1]
)
elif "," in last_col:
last_col = "".join(last_col.split(",", 1))
# check if comma is after a word/closing brace, followed by spaces before -- and remove if present

pre_comment, after_comment = last_col.split("--", 1)
# check if pre_comment ends with a comma with optional spaces
if re.search(r",\s*$", pre_comment):
pre_comment = re.sub(r",\s*$", "", pre_comment)
# remove any trailing spaces in pre_comment
pre_comment = pre_comment.rstrip()
last_col = pre_comment + " --" + after_comment
# if last_col ends with a comma with optional spaces, remove it
elif re.search(r",\s*$", last_col):
last_col = re.sub(r",\s*$", "", last_col)
# remove any trailing spaces in last_col
last_col = last_col.rstrip()
fixed_cols[-1] = last_col
return fixed_cols

Expand Down
30 changes: 29 additions & 1 deletion tests/test_utils_sql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
from defog_utils.defog_utils.utils_sql import (
add_space_padding,
fix_comma,
get_schema_features,
get_sql_features,
is_date_or_time_str,
Expand Down Expand Up @@ -821,6 +822,23 @@ def test_schema_1(self):
self.assertEqual(positive_features, expected_positive)


class TestFixComma(unittest.TestCase):
def test_fix_comma_1(self):
cols = [
" CUSTOMER_EMAIL VARCHAR,",
" CUSTOMER_PHONE VARCHAR(200) --Phone number of the customer", # add comma
" value numeric(10,2),", # remove trailing comma
]
expected = [
" CUSTOMER_EMAIL VARCHAR,",
" CUSTOMER_PHONE VARCHAR(200), --Phone number of the customer",
" value numeric(10,2)",
]
result = fix_comma(cols)
print(result)
self.assertEqual(result, expected)


class TestShuffleTableMetadata(unittest.TestCase):
def test_shuffle_table_metadata_seed_1(self):
input_md_str = """CREATE SCHEMA IF NOT EXISTS TEST_DB;
Expand Down Expand Up @@ -874,10 +892,20 @@ def test_shuffle_table_metadata_seed_1(self):
def test_shuffle_table_metadata_seed_2(self):
input_md_str = """CREATE TABLE branch_info (
branch_open_date date, --Date branch opened
value numeric(10,2),
manager_name varchar(100) --Name of the branch manager
);
CREATE TABLE employee (
employee_id integer,
ytd_return numeric(5,2)
);"""
expected_md_shuffled = """CREATE TABLE branch_info (
expected_md_shuffled = """CREATE TABLE employee (
employee_id integer,
ytd_return numeric(5,2)
);
CREATE TABLE branch_info (
manager_name varchar(100), --Name of the branch manager
value numeric(10,2),
branch_open_date date --Date branch opened
);"""
md_shuffled = shuffle_table_metadata(input_md_str, 0)
Expand Down

0 comments on commit 9d7660c

Please sign in to comment.