Skip to content

Commit

Permalink
fix bug in replace_alias due to case mismatch
Browse files Browse the repository at this point in the history
add `test_valid_md_sql` to test md and sql in a single function call
  • Loading branch information
wongjingping committed Jun 10, 2024
1 parent 3e13fd6 commit 14a5d9b
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 6 deletions.
66 changes: 66 additions & 0 deletions defog_utils/utils_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,71 @@ def fix_md(md: Dict[str, List[Dict[str, str]]]) -> Dict[str, List[Dict[str, str]
return md_new


def test_valid_md_sql(sql: str, md: dict, creds: Dict = None, conn = None, verbose: bool = False):
"""
Test custom metadata and a sql query
This will perform the following steps:
1. Delete the tables in the metadata (to ensure that similarly named tables from previous tests are not used)
2. Create the tables in the metadata. If any errors occur with the metadata, we return early.
3. Run the sql query
4. Delete the tables created
If provided with the variable `conn`, this reuses the same database connection
to avoid creating a new connection for each query. Otherwise it will connect
via psycopg2 using the credentials provided (note that creds should set db_name)
This will not manage `conn` in any way (eg closing `conn`) - it is left to
the caller to manage the connection.
Returns tuple of (sql_valid, md_valid, err_message)
"""
try:
local_conn = False
if conn is not None and conn.closed == 0:
cur = conn.cursor()
else:
conn = psycopg2.connect(
dbname=creds["db_name"],
user=creds["user"],
password=creds["password"],
host=creds["host"],
port=creds["port"],
)
local_conn = True
cur = conn.cursor()
delete_ddl = mk_delete_ddl(md)
cur.execute(delete_ddl)
if verbose:
print(f"Deleted tables with: {delete_ddl}")
create_ddl = mk_create_ddl(md)
cur.execute(create_ddl)
if verbose:
print(f"Created tables with: {create_ddl}")
except Exception as e:
if "cur" in locals() or "cur" in globals():
cur.close()
if local_conn:
conn.close()
return False, False, e
try:
cur.execute(sql)
results = cur.fetchall()
if verbose:
for row in results:
print(row)
delete_ddl = mk_delete_ddl(md)
cur.execute(delete_ddl)
if verbose:
print(f"Deleted tables with: {delete_ddl}")
cur.close()
if local_conn:
conn.close()
return True, True, None
except Exception as e:
if "cur" in locals() or "cur" in globals():
cur.close()
if local_conn:
conn.close()
return False, True, e


def test_valid_md(
sql: str, md: dict, creds: dict, verbose: bool = False, idx: str = ""
) -> Tuple[bool, Optional[Exception]]:
Expand Down Expand Up @@ -489,6 +554,7 @@ def generate_aliases_dict(
) -> Dict[str, str]:
"""
Generate aliases for table names as a dictionary mapping of table names to aliases
Aliases should always be in lower case
"""
aliases = {}
for original_table_name in table_names:
Expand Down
17 changes: 11 additions & 6 deletions defog_utils/utils_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,11 @@ def replace_alias(
sql: str, new_alias_map: Dict[str, str], dialect: str = "postgres"
) -> str:
"""
Replaces the table aliases in the SQL query with the new aliases provided in the new_alias_map.
Replaces the table aliases in the SQL query with the new aliases provided in
the new_alias_map.
`new_alias_map` is a dict of table_name -> new_alias.
Note that aliases are always in lowercase, and will be converted to lowercase
if necessary.
"""
parsed = parse_one(sql, dialect=dialect)
existing_alias_map = {}
Expand All @@ -667,7 +670,8 @@ def replace_alias(
table_name = node.name
# save the existing alias if present
if node.alias:
existing_alias_map[node.alias] = table_name
node_alias = node.alias.lower()
existing_alias_map[node_alias] = table_name
# set the alias to the new alias if it exists in the new_alias_map
if table_name in new_alias_map:
node.set("alias", new_alias_map[table_name])
Expand All @@ -677,12 +681,13 @@ def replace_alias(
for node in parsed.walk():
if isinstance(node, exp.Column):
if node.table:
node_table = node.table.lower()
# if in existing alias map, set the table to the new alias
if node.table in existing_alias_map:
original_table_name = existing_alias_map[node.table]
if node_table in existing_alias_map:
original_table_name = existing_alias_map[node_table]
if original_table_name in new_alias_map:
node.set("table", new_alias_map[original_table_name])
# else if in new alias map, set the table to the new alias
elif node.table in new_alias_map:
node.set("table", new_alias_map[node.table])
elif node_table in new_alias_map:
node.set("table", new_alias_map[node_table])
return parsed.sql(dialect, normalize_functions="upper", comments=False)
7 changes: 7 additions & 0 deletions tests/test_utils_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,13 @@ def test_sql_2(self):
print(result)
self.assertEqual(result, expected)

def test_sql_3(self):
sql = "SELECT CAST((SELECT COUNT(aw.artwork_id) FROM artwork aw WHERE aw.year_created = 1888 AND aw.description IS NULL) AS FLOAT) / NULLIF((SELECT COUNT(at.artist_id) FROM artists AT WHERE at.nationality ilike '%French%'), 0) AS ratio;"
new_alias_map = {'exhibit_artworks': 'ea', 'exhibitions': 'e', 'collaborations': 'c', 'artwork': 'a', 'artists': 'ar'}
expected = "SELECT CAST((SELECT COUNT(a.artwork_id) FROM artwork AS a WHERE a.year_created = 1888 AND a.description IS NULL) AS DOUBLE PRECISION) / NULLIF((SELECT COUNT(ar.artist_id) FROM artists AS ar WHERE ar.nationality ILIKE '%French%'), 0) AS ratio"
result = replace_alias(sql, new_alias_map)
print(result)
self.assertEqual(result, expected)

if __name__ == "__main__":
unittest.main()

0 comments on commit 14a5d9b

Please sign in to comment.