Skip to content

Commit

Permalink
Merge pull request #27 from defog-ai/wendy/tsql_features
Browse files Browse the repository at this point in the history
TSQL features
  • Loading branch information
wendy-aw authored Aug 19, 2024
2 parents ae36c54 + ddffbf2 commit d0f0a2b
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
19 changes: 17 additions & 2 deletions defog_utils/utils_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class SqlFeatures(Features):
has_in: bool = False
month_name_case_in: bool = False
addition: bool = False
date_add: bool = False
subtraction: bool = False
ratio: bool = False
round: bool = False
Expand Down Expand Up @@ -120,7 +121,7 @@ class SqlFeatures(Features):
exp.CurrentTime,
exp.CurrentTimestamp,
]
date_time_types = ["DATE", "TIMESTAMP"]
date_time_types = ["DATE", "DATETIME", "TIMESTAMP"]
int_types = ["INT", "INTEGER", "BIGINT", "SMALLINT", "UINT", "UBIGINT"]
comparison_expressions = [
# binary op with 2 children
Expand Down Expand Up @@ -363,7 +364,15 @@ def get_sql_features(
if has_month_name(str(node)):
features.month_name_case_in = True
elif isinstance(node, exp.Add):
for subnode in node.flatten():
if isinstance(subnode, exp.Literal):
features.string_concat = True
features.addition = True
elif isinstance(node, exp.DateAdd):
features.date_add = True
for sub_node in node.flatten():
if isinstance(sub_node, exp.Neg):
features.date_sub = True
elif isinstance(node, exp.Sub):
features.subtraction = True
date_cols_in_sub = 0
Expand Down Expand Up @@ -429,10 +438,16 @@ def get_sql_features(
features.date_sub = True
elif isinstance(node, exp.DateTrunc) or isinstance(node, exp.TimestampTrunc):
features.date_trunc = True
elif isinstance(node, exp.Convert):
for sub_node in node.flatten():
if str(sub_node) in date_time_types:
features.date_time_type_conversion = True
elif isinstance(node, exp.StrToDate):
features.date_time_type_conversion = True
elif isinstance(node, exp.StrToTime):
features.date_time_type_conversion = True
elif isinstance(node, exp.DateFromParts):
features.date_time_type_conversion = True
elif isinstance(node, exp.Extract):
features.date_part = True
elif type(node) in current_date_time_expressions:
Expand Down Expand Up @@ -493,7 +508,7 @@ def get_sql_features(
or node_name == "percent_rank"
):
features.rank = True
elif node_name == "date_part":
elif node_name == "date_part" or node_name == "datepart":
features.date_part = True
elif node_name == "strftime":
features.strftime = True
Expand Down
33 changes: 31 additions & 2 deletions tests/test_utils_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,12 @@ def test_date_sub_date(self):
self.assertFalse(features.date_sub_date)
self.assertFalse(features.date_sub)

# DATEDIFF(date1, date2)
sql = "SELECT DATEDIFF(column1, column2) FROM table"
features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql")
self.assertTrue(features.date_sub_date)
self.assertFalse(features.date_sub)

def test_current_date_time(self):
sql = "SELECT col_date - CURRENT_DATE FROM table"
features = get_sql_features(sql, self.md_cols, self.md_tables)
Expand Down Expand Up @@ -338,6 +344,21 @@ def test_date_time_type_conversion(self):
sql, self.md_cols, self.md_tables, dialect="postgres"
)
self.assertTrue(features.date_time_type_conversion)
sql = "SELECT CONVERT(DATE, date_str_column) FROM table"
features = get_sql_features(
sql, self.md_cols, self.md_tables, dialect="tsql"
)
self.assertTrue(features.date_time_type_conversion)
sql = "SELECT CONVERT(INT, col) FROM table"
features = get_sql_features(
sql, self.md_cols, self.md_tables, dialect="tsql"
)
self.assertFalse(features.date_time_type_conversion)
sql = "SELECT DATEFROMPARTS(year_column, month_column, day_column) FROM table"
features = get_sql_features(
sql, self.md_cols, self.md_tables, dialect="tsql"
)
self.assertTrue(features.date_time_type_conversion)

def test_date_time_format(self):
sql = "SELECT TO_CHAR(column, 'YYYY-MM-DD') FROM table"
Expand All @@ -352,6 +373,11 @@ def test_date_time_format(self):
)
self.assertFalse(features.date_time_format)
self.assertFalse(features.strftime)
sql = "SELECT FORMAT(column, 'YYYY-MM-DD') FROM table"
features = get_sql_features(
sql, self.md_cols, self.md_tables, dialect="tsql"
)
self.assertTrue(features.date_time_format)

def test_generate_timeseries(self):
sql = "SELECT generate_series(1, 10)"
Expand All @@ -374,6 +400,9 @@ def test_string_concat(self):
sql = "SELECT name || ' ' || description FROM table1"
features = get_sql_features(sql, self.md_cols, self.md_tables)
self.assertTrue(features.string_concat)
sql = "SELECT name + ' ' + description FROM table1"
features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql")
self.assertTrue(features.string_concat)

def test_string_exact_match(self):
sql = "SELECT * FROM table1 WHERE name = 'Exact Match'"
Expand Down Expand Up @@ -587,7 +616,7 @@ def test_complex_sql_1(self):
features = get_sql_features(sql, self.md_cols, self.md_tables)
features_compact = features.compact()
print(features_compact)
expected_compact = "5,2,1,1,0,1,1,1,1,0,0,0,0,0,1,0,0,1,1,1,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0"
expected_compact = "5,2,1,1,0,1,1,1,1,0,0,0,0,0,0,1,0,0,1,1,1,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0"
self.assertEqual(features_compact, expected_compact)
positive_features = features.positive_features()
expected_positive = {
Expand Down Expand Up @@ -626,7 +655,7 @@ def test_complex_sql_2(self):
)
features_compact = features.compact()
print(features_compact)
expected_compact = "3,1,1,1,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0"
expected_compact = "3,1,1,1,0,1,0,0,1,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0"
self.assertEqual(features_compact, expected_compact)
positive_features = features.positive_features()
expected_positive = {
Expand Down

0 comments on commit d0f0a2b

Please sign in to comment.