diff --git a/defog_utils/utils_sql.py b/defog_utils/utils_sql.py index 47797e5..9ed2f31 100644 --- a/defog_utils/utils_sql.py +++ b/defog_utils/utils_sql.py @@ -58,6 +58,7 @@ class SqlFeatures(Features): has_in: bool = False month_name_case_in: bool = False addition: bool = False + date_add: bool = False subtraction: bool = False ratio: bool = False round: bool = False @@ -120,7 +121,7 @@ class SqlFeatures(Features): exp.CurrentTime, exp.CurrentTimestamp, ] -date_time_types = ["DATE", "TIMESTAMP"] +date_time_types = ["DATE", "DATETIME", "TIMESTAMP"] int_types = ["INT", "INTEGER", "BIGINT", "SMALLINT", "UINT", "UBIGINT"] comparison_expressions = [ # binary op with 2 children @@ -363,7 +364,15 @@ def get_sql_features( if has_month_name(str(node)): features.month_name_case_in = True elif isinstance(node, exp.Add): + for subnode in node.flatten(): + if isinstance(subnode, exp.Literal): + features.string_concat = True features.addition = True + elif isinstance(node, exp.DateAdd): + features.date_add = True + for sub_node in node.flatten(): + if isinstance(sub_node, exp.Neg): + features.date_sub = True elif isinstance(node, exp.Sub): features.subtraction = True date_cols_in_sub = 0 @@ -429,10 +438,16 @@ def get_sql_features( features.date_sub = True elif isinstance(node, exp.DateTrunc) or isinstance(node, exp.TimestampTrunc): features.date_trunc = True + elif isinstance(node, exp.Convert): + for sub_node in node.flatten(): + if str(sub_node) in date_time_types: + features.date_time_type_conversion = True elif isinstance(node, exp.StrToDate): features.date_time_type_conversion = True elif isinstance(node, exp.StrToTime): features.date_time_type_conversion = True + elif isinstance(node, exp.DateFromParts): + features.date_time_type_conversion = True elif isinstance(node, exp.Extract): features.date_part = True elif type(node) in current_date_time_expressions: @@ -493,7 +508,7 @@ def get_sql_features( or node_name == "percent_rank" ): features.rank = True - elif node_name == "date_part": + elif node_name == "date_part" or node_name == "datepart": features.date_part = True elif node_name == "strftime": features.strftime = True diff --git a/tests/test_utils_sql.py b/tests/test_utils_sql.py index d120745..f31d6bb 100644 --- a/tests/test_utils_sql.py +++ b/tests/test_utils_sql.py @@ -280,6 +280,12 @@ def test_date_sub_date(self): self.assertFalse(features.date_sub_date) self.assertFalse(features.date_sub) + # DATEDIFF(date1, date2) + sql = "SELECT DATEDIFF(column1, column2) FROM table" + features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql") + self.assertTrue(features.date_sub_date) + self.assertFalse(features.date_sub) + def test_current_date_time(self): sql = "SELECT col_date - CURRENT_DATE FROM table" features = get_sql_features(sql, self.md_cols, self.md_tables) @@ -338,6 +344,21 @@ def test_date_time_type_conversion(self): sql, self.md_cols, self.md_tables, dialect="postgres" ) self.assertTrue(features.date_time_type_conversion) + sql = "SELECT CONVERT(DATE, date_str_column) FROM table" + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="tsql" + ) + self.assertTrue(features.date_time_type_conversion) + sql = "SELECT CONVERT(INT, col) FROM table" + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="tsql" + ) + self.assertFalse(features.date_time_type_conversion) + sql = "SELECT DATEFROMPARTS(year_column, month_column, day_column) FROM table" + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="tsql" + ) + self.assertTrue(features.date_time_type_conversion) def test_date_time_format(self): sql = "SELECT TO_CHAR(column, 'YYYY-MM-DD') FROM table" @@ -352,6 +373,11 @@ def test_date_time_format(self): ) self.assertFalse(features.date_time_format) self.assertFalse(features.strftime) + sql = "SELECT FORMAT(column, 'YYYY-MM-DD') FROM table" + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="tsql" + ) + self.assertTrue(features.date_time_format) def test_generate_timeseries(self): sql = "SELECT generate_series(1, 10)" @@ -374,6 +400,9 @@ def test_string_concat(self): sql = "SELECT name || ' ' || description FROM table1" features = get_sql_features(sql, self.md_cols, self.md_tables) self.assertTrue(features.string_concat) + sql = "SELECT name + ' ' + description FROM table1" + features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql") + self.assertTrue(features.string_concat) def test_string_exact_match(self): sql = "SELECT * FROM table1 WHERE name = 'Exact Match'" @@ -587,7 +616,7 @@ def test_complex_sql_1(self): features = get_sql_features(sql, self.md_cols, self.md_tables) features_compact = features.compact() print(features_compact) - expected_compact = "5,2,1,1,0,1,1,1,1,0,0,0,0,0,1,0,0,1,1,1,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0" + expected_compact = "5,2,1,1,0,1,1,1,1,0,0,0,0,0,0,1,0,0,1,1,1,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0" self.assertEqual(features_compact, expected_compact) positive_features = features.positive_features() expected_positive = { @@ -626,7 +655,7 @@ def test_complex_sql_2(self): ) features_compact = features.compact() print(features_compact) - expected_compact = "3,1,1,1,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0" + expected_compact = "3,1,1,1,0,1,0,0,1,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0" self.assertEqual(features_compact, expected_compact) positive_features = features.positive_features() expected_positive = {