From 1bd9674491fd02b69eb8d79ba03ac1d330f0cef0 Mon Sep 17 00:00:00 2001 From: jp Date: Mon, 5 Aug 2024 09:29:13 +0800 Subject: [PATCH 1/4] add vscode settings to facilitate easier testing add additional test using column names with spaces --- .vscode/settings.json | 10 ++++++ tests/test_utils_db.py | 71 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..8ed99fa --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,10 @@ +{ + "python.analysis.extraPaths": [ + "./defog_utils" + ], + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, +} \ No newline at end of file diff --git a/tests/test_utils_db.py b/tests/test_utils_db.py index 99c93fe..9aacee5 100644 --- a/tests/test_utils_db.py +++ b/tests/test_utils_db.py @@ -51,6 +51,77 @@ def test_mk_create_table_ddl(self): ");\n" ) self.assertEqual(mk_create_table_ddl(table_name, columns), expected_output) + + def test_mk_create_table_ddl_spaces(self): + table_name = "table1" + columns = [ + { + "data_type": "text", + "column_name": "Invoice Number", + "column_description": "Unique identifier for each invoice" + }, + { + "data_type": "text", + "column_name": "Invoice Date", + "column_description": "Date when the invoice was issued" + }, + { + "data_type": "text", + "column_name": "Sales Order#", + "column_description": "Sales order number associated with the invoice" + }, + { + "data_type": "text", + "column_name": "Customer Name", + "column_description": "Name of the customer who made the purchase" + }, + { + "data_type": "int", + "column_name": "Total with out GST", + "column_description": "Total amount of the invoice without including GST" + }, + { + "data_type": "int", + "column_name": "Total", + "column_description": "Total amount of the invoice including GST" + }, + { + "data_type": "text", + "column_name": "Status", + "column_description": "Current status of the invoice" + }, + { + "data_type": "text", + "column_name": "Salesperson Name", + "column_description": "Name of the salesperson who handled the sale" + }, + { + "data_type": "text", + "column_name": "Account Type", + "column_description": "Type of account associated with the invoice" + }, + { + "data_type": "text", + "column_name": "Item Category", + "column_description": "Category of the item purchased" + } + ] + expected_output = ( + "CREATE TABLE table1 (\n" + " \"Invoice Number\" text, --Unique identifier for each invoice\n" + " \"Invoice Date\" text, --Date when the invoice was issued\n" + " \"Sales Order#\" text, --Sales order number associated with the invoice\n" + " \"Customer Name\" text, --Name of the customer who made the purchase\n" + " \"Total with out GST\" integer, --Total amount of the invoice without including GST\n" + " Total integer, --Total amount of the invoice including GST\n" + " Status text, --Current status of the invoice\n" + " \"Salesperson Name\" text, --Name of the salesperson who handled the sale\n" + " \"Account Type\" text, --Type of account associated with the invoice\n" + " \"Item Category\" text --Category of the item purchased\n" + ");\n" + ) + self.assertEqual(mk_create_table_ddl(table_name, columns), expected_output) + class TestMkCreateDDL(unittest.TestCase): From 3e0a275eb1ca2c37e24233f0196eb1ea7cfe1ea8 Mon Sep 17 00:00:00 2001 From: jp Date: Mon, 5 Aug 2024 09:50:51 +0800 Subject: [PATCH 2/4] lint and add sqlglotrs version --- defog_utils/utils_db.py | 10 ++- defog_utils/utils_sql.py | 2 + requirements.txt | 2 +- tests/test_utils_db.py | 170 +++++++++++++++++++++++++++++---------- tests/test_utils_sql.py | 95 +++++++++++++++------- 5 files changed, 204 insertions(+), 75 deletions(-) diff --git a/defog_utils/utils_db.py b/defog_utils/utils_db.py index 4d8587f..f7b4173 100644 --- a/defog_utils/utils_db.py +++ b/defog_utils/utils_db.py @@ -234,7 +234,7 @@ def mk_delete_ddl(md: Dict[str, Any]) -> str: # check if the contents is a dictionary of tables or a list of tables is_schema = isinstance(contents, Dict) break - + if is_schema: md_delete = "" for schema, tables in md.items(): @@ -288,7 +288,9 @@ def fix_md(md: Dict[str, List[Dict[str, str]]]) -> Dict[str, List[Dict[str, str] return md_new -def test_valid_md_sql(sql: str, md: dict, creds: Dict = None, conn = None, verbose: bool = False): +def test_valid_md_sql( + sql: str, md: dict, creds: Dict = None, conn=None, verbose: bool = False +): """ Test custom metadata and a sql query This will perform the following steps: @@ -299,7 +301,7 @@ def test_valid_md_sql(sql: str, md: dict, creds: Dict = None, conn = None, verbo If provided with the variable `conn`, this reuses the same database connection to avoid creating a new connection for each query. Otherwise it will connect via psycopg2 using the credentials provided (note that creds should set db_name) - This will not manage `conn` in any way (eg closing `conn`) - it is left to + This will not manage `conn` in any way (eg closing `conn`) - it is left to the caller to manage the connection. Returns tuple of (sql_valid, md_valid, err_message) """ @@ -546,7 +548,7 @@ def parse_md(md_str: str) -> Dict[str, List[Dict[str, str]]]: def get_table_names(md: str) -> List[str]: """ Given a string of metadata formatted as a series of - CREATE TABLE statements, return a list of table names in the same order as + CREATE TABLE statements, return a list of table names in the same order as they appear in the metadata. """ table_names = [] diff --git a/defog_utils/utils_sql.py b/defog_utils/utils_sql.py index f1c2645..0bc09bc 100644 --- a/defog_utils/utils_sql.py +++ b/defog_utils/utils_sql.py @@ -520,9 +520,11 @@ def is_date_or_time_str(s: str) -> bool: m = re.match(date_or_time_pattern, s) return bool(m) + def has_month_name(s: str) -> bool: return bool(re.search(month_name_pattern, s, re.IGNORECASE)) + def has_date_in_name(s: str) -> bool: return bool(re.search(r"(year|quarter|month|week|day)", s)) diff --git a/requirements.txt b/requirements.txt index 5408e58..085dfb8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ numpy psycopg2-binary -sqlglot +sqlglotrs>=0.2.8 sqlparse \ No newline at end of file diff --git a/tests/test_utils_db.py b/tests/test_utils_db.py index 9aacee5..7de2c0b 100644 --- a/tests/test_utils_db.py +++ b/tests/test_utils_db.py @@ -51,79 +51,78 @@ def test_mk_create_table_ddl(self): ");\n" ) self.assertEqual(mk_create_table_ddl(table_name, columns), expected_output) - + def test_mk_create_table_ddl_spaces(self): table_name = "table1" columns = [ { "data_type": "text", "column_name": "Invoice Number", - "column_description": "Unique identifier for each invoice" + "column_description": "Unique identifier for each invoice", }, { "data_type": "text", "column_name": "Invoice Date", - "column_description": "Date when the invoice was issued" + "column_description": "Date when the invoice was issued", }, { "data_type": "text", "column_name": "Sales Order#", - "column_description": "Sales order number associated with the invoice" + "column_description": "Sales order number associated with the invoice", }, { "data_type": "text", "column_name": "Customer Name", - "column_description": "Name of the customer who made the purchase" + "column_description": "Name of the customer who made the purchase", }, { "data_type": "int", "column_name": "Total with out GST", - "column_description": "Total amount of the invoice without including GST" + "column_description": "Total amount of the invoice without including GST", }, { "data_type": "int", "column_name": "Total", - "column_description": "Total amount of the invoice including GST" + "column_description": "Total amount of the invoice including GST", }, { "data_type": "text", "column_name": "Status", - "column_description": "Current status of the invoice" + "column_description": "Current status of the invoice", }, { "data_type": "text", "column_name": "Salesperson Name", - "column_description": "Name of the salesperson who handled the sale" + "column_description": "Name of the salesperson who handled the sale", }, { "data_type": "text", "column_name": "Account Type", - "column_description": "Type of account associated with the invoice" + "column_description": "Type of account associated with the invoice", }, { "data_type": "text", "column_name": "Item Category", - "column_description": "Category of the item purchased" - } + "column_description": "Category of the item purchased", + }, ] expected_output = ( "CREATE TABLE table1 (\n" - " \"Invoice Number\" text, --Unique identifier for each invoice\n" - " \"Invoice Date\" text, --Date when the invoice was issued\n" - " \"Sales Order#\" text, --Sales order number associated with the invoice\n" - " \"Customer Name\" text, --Name of the customer who made the purchase\n" - " \"Total with out GST\" integer, --Total amount of the invoice without including GST\n" + ' "Invoice Number" text, --Unique identifier for each invoice\n' + ' "Invoice Date" text, --Date when the invoice was issued\n' + ' "Sales Order#" text, --Sales order number associated with the invoice\n' + ' "Customer Name" text, --Name of the customer who made the purchase\n' + ' "Total with out GST" integer, --Total amount of the invoice without including GST\n' " Total integer, --Total amount of the invoice including GST\n" " Status text, --Current status of the invoice\n" - " \"Salesperson Name\" text, --Name of the salesperson who handled the sale\n" - " \"Account Type\" text, --Type of account associated with the invoice\n" - " \"Item Category\" text --Category of the item purchased\n" + ' "Salesperson Name" text, --Name of the salesperson who handled the sale\n' + ' "Account Type" text, --Type of account associated with the invoice\n' + ' "Item Category" text --Category of the item purchased\n' ");\n" ) self.assertEqual(mk_create_table_ddl(table_name, columns), expected_output) - class TestMkCreateDDL(unittest.TestCase): def test_mk_create_ddl(self): md = { @@ -469,24 +468,64 @@ def test_parse_md_2(self): );""" expected = { "acct_trx": [ - {"column_name": "trx_units", "data_type": "numeric(10,2)", "column_description": ""}, - {"column_name": "asset_id", "data_type": "integer", "column_description": ""}, - {"column_name": "trx_amount", "data_type": "numeric(10,2)", "column_description": ""}, - {"column_name": "details", "data_type": "varchar(500)", "column_description": ""}, - {"column_name": "id", "data_type": "integer", "column_description": "Primary key for acct_trx table, joinable with other tables"}, - {"column_name": "settle_date", "data_type": "date", "column_description": "Date transaction settled"}, - {"column_name": "symbol", "data_type": "varchar(10)", "column_description": ""}, + { + "column_name": "trx_units", + "data_type": "numeric(10,2)", + "column_description": "", + }, + { + "column_name": "asset_id", + "data_type": "integer", + "column_description": "", + }, + { + "column_name": "trx_amount", + "data_type": "numeric(10,2)", + "column_description": "", + }, + { + "column_name": "details", + "data_type": "varchar(500)", + "column_description": "", + }, + { + "column_name": "id", + "data_type": "integer", + "column_description": "Primary key for acct_trx table, joinable with other tables", + }, + { + "column_name": "settle_date", + "data_type": "date", + "column_description": "Date transaction settled", + }, + { + "column_name": "symbol", + "data_type": "varchar(10)", + "column_description": "", + }, ], "acct_perf": [ - {"column_name": "ytd_return", "data_type": "numeric(5,2)", "column_description": ""}, - {"column_name": "acct_snapshot_date", "data_type": "text", "column_description": "format: yyyy-mm-dd"}, - {"column_name": "account_id", "data_type": "integer", "column_description": "Primary key, foreign key to cust_acct table"}, + { + "column_name": "ytd_return", + "data_type": "numeric(5,2)", + "column_description": "", + }, + { + "column_name": "acct_snapshot_date", + "data_type": "text", + "column_description": "format: yyyy-mm-dd", + }, + { + "column_name": "account_id", + "data_type": "integer", + "column_description": "Primary key, foreign key to cust_acct table", + }, ], } md = parse_md(md_str) print(md) self.assertDictEqual(md, expected) - + def test_parse_md_3(self): md_str = """CREATE TABLE acct_trx ( trx_units numeric(10, 2), @@ -504,18 +543,58 @@ def test_parse_md_3(self): );""" expected = { "acct_trx": [ - {"column_name": "trx_units", "data_type": "numeric(10, 2)", "column_description": ""}, - {"column_name": "asset_id", "data_type": "integer", "column_description": ""}, - {"column_name": "trx_amount", "data_type": "numeric(10, 2)", "column_description": ""}, - {"column_name": "details", "data_type": "varchar(500)", "column_description": ""}, - {"column_name": "id", "data_type": "integer", "column_description": "Primary key for acct_trx table, joinable with other tables"}, - {"column_name": "settle_date", "data_type": "date", "column_description": "Date transaction settled"}, - {"column_name": "symbol", "data_type": "varchar(10)", "column_description": ""}, + { + "column_name": "trx_units", + "data_type": "numeric(10, 2)", + "column_description": "", + }, + { + "column_name": "asset_id", + "data_type": "integer", + "column_description": "", + }, + { + "column_name": "trx_amount", + "data_type": "numeric(10, 2)", + "column_description": "", + }, + { + "column_name": "details", + "data_type": "varchar(500)", + "column_description": "", + }, + { + "column_name": "id", + "data_type": "integer", + "column_description": "Primary key for acct_trx table, joinable with other tables", + }, + { + "column_name": "settle_date", + "data_type": "date", + "column_description": "Date transaction settled", + }, + { + "column_name": "symbol", + "data_type": "varchar(10)", + "column_description": "", + }, ], "acct_perf": [ - {"column_name": "ytd_return", "data_type": "numeric(5, 2)", "column_description": ""}, - {"column_name": "acct_snapshot_date", "data_type": "text", "column_description": "format: yyyy-mm-dd"}, - {"column_name": "account_id", "data_type": "integer", "column_description": "Primary key, foreign key to cust_acct table"}, + { + "column_name": "ytd_return", + "data_type": "numeric(5, 2)", + "column_description": "", + }, + { + "column_name": "acct_snapshot_date", + "data_type": "text", + "column_description": "format: yyyy-mm-dd", + }, + { + "column_name": "account_id", + "data_type": "integer", + "column_description": "Primary key, foreign key to cust_acct table", + }, ], } md = parse_md(md_str) @@ -695,7 +774,12 @@ def test_generate_aliases_with_reserved_keywords(self): self.assertEqual(result, expected_result) def test_generate_aliases_with_dots_and_underscores(self): - table_names = ["db.schema.table1", "db.schema.table2", "db.schema.table3", "_uncompressed___long_name_"] + table_names = [ + "db.schema.table1", + "db.schema.table2", + "db.schema.table3", + "_uncompressed___long_name_", + ] result = generate_aliases(table_names) print(result) expected_result = "-- db.schema.table1 AS t1\n-- db.schema.table2 AS t2\n-- db.schema.table3 AS t3\n-- _uncompressed___long_name_ AS uln\n" diff --git a/tests/test_utils_sql.py b/tests/test_utils_sql.py index 7b1360b..d120745 100644 --- a/tests/test_utils_sql.py +++ b/tests/test_utils_sql.py @@ -57,7 +57,7 @@ def test_join_left(self): sql = "SELECT * FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id" features = get_sql_features(sql, self.md_cols, self.md_tables) self.assertFalse(features.join_left) - + def test_addition(self): sql = "SELECT column1 + column2 FROM table" features = get_sql_features(sql, self.md_cols, self.md_tables) @@ -105,7 +105,7 @@ def test_agg_min_max(self): features = get_sql_features(sql, self.md_cols, self.md_tables) self.assertFalse(features.agg_min) self.assertTrue(features.agg_max) - + def test_nested_agg(self): sql = "SELECT COUNT(col), MIN(col), SUM(col2-col3) FROM table" features = get_sql_features(sql, self.md_cols, self.md_tables) @@ -181,7 +181,7 @@ def test_has_date_columns(self): self.assertTrue(features.has_date_text) self.assertFalse(features.has_date_int) self.assertTrue(features.date_literal) - + def test_date_literal(self): sql = "SELECT column1 FROM table WHERE column2 = '2023-01-01'" features = get_sql_features(sql, self.md_cols, self.md_tables) @@ -203,7 +203,9 @@ def test_date_trunc(self): sql = "SELECT DATE_TRUNC('day', column) FROM table" features = get_sql_features(sql, self.md_cols, self.md_tables) self.assertTrue(features.date_trunc) - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.date_trunc) def test_strftime(self): @@ -211,7 +213,9 @@ def test_strftime(self): features = get_sql_features(sql, self.md_cols, self.md_tables) self.assertTrue(features.strftime) self.assertFalse(features.date_time_format) - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.strftime) self.assertFalse(features.date_time_format) @@ -232,11 +236,13 @@ def test_date_comparison(self): "date_type_text": set(), } for sql in [sql_left, sql_right]: - features_with_empty_col_info = \ - get_sql_features(sql, self.md_cols, self.md_tables, self.empty_extra_column_info) + features_with_empty_col_info = get_sql_features( + sql, self.md_cols, self.md_tables, self.empty_extra_column_info + ) self.assertFalse(features_with_empty_col_info.date_comparison) - features_with_date_col_info = \ - get_sql_features(sql, self.md_cols, self.md_tables, extra_column_info) + features_with_date_col_info = get_sql_features( + sql, self.md_cols, self.md_tables, extra_column_info + ) self.assertTrue(features_with_date_col_info.date_comparison) def test_date_sub_date(self): @@ -257,7 +263,9 @@ def test_date_sub_date(self): "date_type_text": set(), } # date - date - features = get_sql_features(sql, self.md_cols, self.md_tables, extra_column_info_both) + features = get_sql_features( + sql, self.md_cols, self.md_tables, extra_column_info_both + ) self.assertTrue(features.date_sub_date) self.assertFalse(features.date_sub) # x - date or date - x @@ -266,7 +274,9 @@ def test_date_sub_date(self): self.assertFalse(features.date_sub_date) self.assertTrue(features.date_sub) # x - x - features = get_sql_features(sql, self.md_cols, self.md_tables, self.empty_extra_column_info) + features = get_sql_features( + sql, self.md_cols, self.md_tables, self.empty_extra_column_info + ) self.assertFalse(features.date_sub_date) self.assertFalse(features.date_sub) @@ -294,46 +304,70 @@ def test_interval(self): def test_date_time_type_conversion(self): sql = "SELECT CAST(column AS TIMESTAMP) FROM table" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.date_time_type_conversion) sql = "SELECT CAST(column AS DATE) FROM table" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.date_time_type_conversion) sql = "SELECT TO_DATE(column, 'YYYY-MM-DD') FROM table" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.date_time_type_conversion) sql = "SELECT TO_TIMESTAMP(column, 'YYYY-MM-DD') FROM table" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.date_time_type_conversion) sql = "SELECT column::TIMESTAMP FROM table" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.date_time_type_conversion) sql = "SELECT column::DATE FROM table" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.date_time_type_conversion) sql = "SELECT DATE(column) FROM table" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.date_time_type_conversion) def test_date_time_format(self): sql = "SELECT TO_CHAR(column, 'YYYY-MM-DD') FROM table" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.date_time_format) self.assertTrue(features.strftime) sql = "SELECT TO_DATE(column, 'YYYY-MM-DD') FROM table" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertFalse(features.date_time_format) self.assertFalse(features.strftime) def test_generate_timeseries(self): sql = "SELECT generate_series(1, 10)" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.generate_timeseries) sql = "SELECT generate_series('2023-01-01'::DATE, '2023-01-10'::DATE, '1 day')" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.generate_timeseries) sql = "SELECT generate_series('2023-01-01'::TIMESTAMP, '2023-01-10'::TIMESTAMP, '1 day')" - features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="postgres") + features = get_sql_features( + sql, self.md_cols, self.md_tables, dialect="postgres" + ) self.assertTrue(features.generate_timeseries) def test_string_concat(self): @@ -1024,8 +1058,8 @@ class TestFixComma(unittest.TestCase): def test_fix_comma_1(self): cols = [ " CUSTOMER_EMAIL VARCHAR,", - " CUSTOMER_PHONE VARCHAR(200) --Phone number of the customer", # add comma - " value numeric(10,2),", # remove trailing comma + " CUSTOMER_PHONE VARCHAR(200) --Phone number of the customer", # add comma + " value numeric(10,2),", # remove trailing comma ] expected = [ " CUSTOMER_EMAIL VARCHAR,", @@ -1086,7 +1120,7 @@ def test_shuffle_table_metadata_seed_1(self): md_shuffled = shuffle_table_metadata(input_md_str, 42) self.maxDiff = None self.assertEqual(md_shuffled, expected_md_shuffled) - + def test_shuffle_table_metadata_seed_2(self): input_md_str = """CREATE TABLE branch_info ( branch_open_date date, --Date branch opened @@ -1163,6 +1197,7 @@ def test_shuffle_table_metadata_seed_4(self): print(md_shuffled) self.assertEqual(md_shuffled, expected_md_shuffled) + class TestFunctions(unittest.TestCase): def test_is_date_or_time_str(self): s1 = "2022-03-19" @@ -1254,7 +1289,13 @@ def test_sql_2(self): def test_sql_3(self): sql = "SELECT CAST((SELECT COUNT(aw.artwork_id) FROM artwork aw WHERE aw.year_created = 1888 AND aw.description IS NULL) AS FLOAT) / NULLIF((SELECT COUNT(at.artist_id) FROM artists AT WHERE at.nationality ilike '%French%'), 0) AS ratio;" - new_alias_map = {'exhibit_artworks': 'ea', 'exhibitions': 'e', 'collaborations': 'c', 'artwork': 'a', 'artists': 'ar'} + new_alias_map = { + "exhibit_artworks": "ea", + "exhibitions": "e", + "collaborations": "c", + "artwork": "a", + "artists": "ar", + } expected = "SELECT CAST((SELECT COUNT(a.artwork_id) FROM artwork AS a WHERE a.year_created = 1888 AND a.description IS NULL) AS DOUBLE PRECISION) / NULLIF((SELECT COUNT(ar.artist_id) FROM artists AS ar WHERE ar.nationality ILIKE '%French%'), 0) AS ratio" result = replace_alias(sql, new_alias_map) print(result) From 59f03f089ac2b7366316d35108d66f3cf16ccf49 Mon Sep 17 00:00:00 2001 From: jp Date: Mon, 5 Aug 2024 09:53:11 +0800 Subject: [PATCH 3/4] update requirements.txt --- requirements.txt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 085dfb8..9881e74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ numpy -psycopg2-binary -sqlglotrs>=0.2.8 -sqlparse \ No newline at end of file +psycopg2-binary==2.9.9 +sqlglot==25.8.1 +sqlglotrs==0.2.8 +sqlparse==0.5.1 \ No newline at end of file From c32cf240d2ddcecd599b4a0492461b2a750bea53 Mon Sep 17 00:00:00 2001 From: jp Date: Mon, 5 Aug 2024 10:00:07 +0800 Subject: [PATCH 4/4] Fix TO_DATE feature --- defog_utils/utils_sql.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/defog_utils/utils_sql.py b/defog_utils/utils_sql.py index 0bc09bc..47797e5 100644 --- a/defog_utils/utils_sql.py +++ b/defog_utils/utils_sql.py @@ -429,6 +429,8 @@ def get_sql_features( features.date_sub = True elif isinstance(node, exp.DateTrunc) or isinstance(node, exp.TimestampTrunc): features.date_trunc = True + elif isinstance(node, exp.StrToDate): + features.date_time_type_conversion = True elif isinstance(node, exp.StrToTime): features.date_time_type_conversion = True elif isinstance(node, exp.Extract):