From 1c3048812aa77ed44f2e3fb063f65229fe42461d Mon Sep 17 00:00:00 2001 From: wendy Date: Mon, 15 Jul 2024 16:37:02 +0800 Subject: [PATCH 1/3] - add new sql feature month_name_case_in - modify fix_comma to accommodate comments startingn with /* instead of -- --- defog_utils/utils_sql.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/defog_utils/utils_sql.py b/defog_utils/utils_sql.py index ecdb554..f1c2645 100644 --- a/defog_utils/utils_sql.py +++ b/defog_utils/utils_sql.py @@ -56,6 +56,7 @@ class SqlFeatures(Features): union: bool = False case_condition: bool = False has_in: bool = False + month_name_case_in: bool = False addition: bool = False subtraction: bool = False ratio: bool = False @@ -106,6 +107,7 @@ class SqlFeatures(Features): time_pattern = r"^(0\d|1\d|2[0-3]):([0-5]\d):([0-5]\d)" date_or_time_pattern = f"({date_pattern}|{time_pattern})" date_column_pattern = r"(date|timestamp)(\s|$)" +month_name_pattern = r"('Jan'|'Feb'|'Mar'|'Apr'|'May'|'Jun'|'Jul'|'Aug'|'Sep'|'Oct'|'Nov'|'Dec'|'January'|'February'|'March'|'April'|'May'|'June'|'July'|'August'|'September'|'October'|'November'|'December')" variance_expressions = [ exp.VariancePop, exp.Variance, @@ -354,8 +356,12 @@ def get_sql_features( features.union = True elif isinstance(node, exp.Case): features.case_condition = True + if has_month_name(str(node)): + features.month_name_case_in = True elif isinstance(node, exp.In): features.has_in = True + if has_month_name(str(node)): + features.month_name_case_in = True elif isinstance(node, exp.Add): features.addition = True elif isinstance(node, exp.Sub): @@ -514,6 +520,8 @@ def is_date_or_time_str(s: str) -> bool: m = re.match(date_or_time_pattern, s) return bool(m) +def has_month_name(s: str) -> bool: + return bool(re.search(month_name_pattern, s, re.IGNORECASE)) def has_date_in_name(s: str) -> bool: return bool(re.search(r"(year|quarter|month|week|day)", s)) @@ -605,6 +613,11 @@ def fix_comma(cols: List[str]) -> List[str]: if not re.search(r",\s*--", col): # use re.sub to replace (any whitespace)-- with , -- col = re.sub(r"\s*--", ", --", col) + elif "/*" in col: + # check if comma is just before comment + if not re.search(r",\s*/\*", col): + # use re.sub to replace (any whitespace)-- with , -- + col = re.sub(r"\s*/\*", ", /*", col) # check if string ends with comma (optionally with additional spaces) elif not re.search(r",\s*$", col): # end with comma if not present @@ -614,7 +627,6 @@ def fix_comma(cols: List[str]) -> List[str]: last_col = fixed_cols[-1] if "--" in last_col: # check if comma is after a word/closing brace, followed by spaces before -- and remove if present - pre_comment, after_comment = last_col.split("--", 1) # check if pre_comment ends with a comma with optional spaces if re.search(r",\s*$", pre_comment): @@ -622,6 +634,15 @@ def fix_comma(cols: List[str]) -> List[str]: # remove any trailing spaces in pre_comment pre_comment = pre_comment.rstrip() last_col = pre_comment + " --" + after_comment + elif "/*" in last_col: + # check if comma is after a word/closing brace, followed by spaces before -- and remove if present + pre_comment, after_comment = last_col.split("/*", 1) + # check if pre_comment ends with a comma with optional spaces + if re.search(r",\s*$", pre_comment): + pre_comment = re.sub(r",\s*$", "", pre_comment) + # remove any trailing spaces in pre_comment + pre_comment = pre_comment.rstrip() + last_col = pre_comment + " /*" + after_comment # if last_col ends with a comma with optional spaces, remove it elif re.search(r",\s*$", last_col): last_col = re.sub(r",\s*$", "", last_col) From 1c4d7ac186975f7de3fea5e7441899b26c0b77b0 Mon Sep 17 00:00:00 2001 From: wendy Date: Mon, 15 Jul 2024 16:40:22 +0800 Subject: [PATCH 2/3] modify test results --- tests/test_utils_sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_utils_sql.py b/tests/test_utils_sql.py index 356c227..d90a9b2 100644 --- a/tests/test_utils_sql.py +++ b/tests/test_utils_sql.py @@ -542,7 +542,7 @@ def test_complex_sql_1(self): features = get_sql_features(sql, self.md_cols, self.md_tables) features_compact = features.compact() print(features_compact) - expected_compact = "5,2,1,1,0,1,1,1,1,0,0,0,0,1,0,0,1,1,1,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0" + expected_compact = "5,2,1,1,0,1,1,1,1,0,0,0,0,0,1,0,0,1,1,1,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0" self.assertEqual(features_compact, expected_compact) positive_features = features.positive_features() expected_positive = { @@ -581,7 +581,7 @@ def test_complex_sql_2(self): ) features_compact = features.compact() print(features_compact) - expected_compact = "3,1,1,1,0,1,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0" + expected_compact = "3,1,1,1,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0" self.assertEqual(features_compact, expected_compact) positive_features = features.positive_features() expected_positive = { From f96a6989cf8568a48edb6c5d0bfbb917640e4b9a Mon Sep 17 00:00:00 2001 From: wendy Date: Mon, 15 Jul 2024 17:10:30 +0800 Subject: [PATCH 3/3] added test for month_name_case_in --- tests/test_utils_sql.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_utils_sql.py b/tests/test_utils_sql.py index d90a9b2..7b1360b 100644 --- a/tests/test_utils_sql.py +++ b/tests/test_utils_sql.py @@ -124,6 +124,17 @@ def test_case_condition(self): features = get_sql_features(sql, self.md_cols, self.md_tables) self.assertTrue(features.case_condition) + def test_month_name_case_in(self): + sql = "SELECT * FROM table WHERE month_col IN ('January', 'February', 'Mar', 'Apr')" + sql2 = "SELECT * FROM table WHERE month_col = 'January'" + sql3 = "SELECT * FROM review r WHERE DATE(CAST(r.year AS TEXT) || '-' || CASE r.month WHEN 'January' THEN '01' WHEN 'February' THEN '02' WHEN 'March' THEN '03' WHEN 'April' THEN '04' WHEN 'May' THEN '05' WHEN 'June' THEN '06' WHEN 'July' THEN '07' WHEN 'August' THEN '08' WHEN 'September' THEN '09' WHEN 'October' THEN '10' WHEN 'November' THEN '11' WHEN 'December' THEN '12' END || '-01') >= DATE('now', '-12 months');" + features = get_sql_features(sql, self.md_cols, self.md_tables) + features2 = get_sql_features(sql2, self.md_cols, self.md_tables) + features3 = get_sql_features(sql3, self.md_cols, self.md_tables) + self.assertTrue(features.month_name_case_in) + self.assertFalse(features2.month_name_case_in) + self.assertTrue(features3.month_name_case_in) + def test_ratio(self): sql = "SELECT column1 / column2 FROM table" features = get_sql_features(sql, self.md_cols, self.md_tables)