Skip to content

Commit

Permalink
- exclude comments from formatting of reserved keywords
Browse files Browse the repository at this point in the history
- remove unparseable tildes in sql before translation
  • Loading branch information
wendy-aw committed Jun 14, 2024
1 parent 397a499 commit ab26dfe
Showing 1 changed file with 49 additions and 19 deletions.
68 changes: 49 additions & 19 deletions utils/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,23 @@ async def sem_task(index, row):
return ordered_results


def remove_unparseable_sql(sql):
"""
Remove unparseable elements from the sql
"""
if " ~* '" in sql:
sql = sql.replace(" ~* '", " ILIKE '")
if " ~ '" in sql:
sql = sql.replace(" ~ '", " LIKE '")
return sql


def sql_to_dialect(sql, db_type, dialect):
"""
Translates sql of db_type to another dialect with sqlglot.
Does not have any post-processing.
"""
sql = remove_unparseable_sql(sql)
translated = sqlglot.transpile(sql, read=db_type, write=dialect)
translated = translated[0]
return translated
Expand Down Expand Up @@ -310,6 +322,18 @@ def ddl_to_bigquery(ddl, db_type, db_name, row_idx):
"""

translated = ddl_to_dialect(ddl, db_type, "bigquery")
translated = ddl_remove_schema(translated)

# if any of reserved keywords in the non-comments section of ddl, enclose them with backticks
reserved_keywords = ["long"]
segments = re.split(r'(/\*.*?\*/)', translated)
for i, segment in enumerate(segments):
if not segment.startswith('/*'):
for keyword in reserved_keywords:
segment = re.sub(rf'(?<!")\b{keyword}\b(?!")', f"`{keyword}`", segment, flags=re.IGNORECASE)
segments[i] = segment
translated = ''.join(segments)

translated = translated.replace(")\nCREATE", ");\nCREATE")
translated = re.sub(r"SERIAL(PRIMARY KEY)?", "INT64", translated)
translated = re.sub(
Expand All @@ -318,15 +342,6 @@ def ddl_to_bigquery(ddl, db_type, db_name, row_idx):
translated,
)
translated += ";"
reserved_keywords = ["long"]
# if any of reserved keywords in the ddl, enclose them with backticks
for keyword in reserved_keywords:
translated = re.sub(
rf"\b{keyword}\b", f"`{keyword}`", translated, flags=re.IGNORECASE
)
# remove schema names
translated = ddl_remove_schema(translated)

translated_ddl = translated.replace("CREATE TABLE ", f"CREATE TABLE {db_name}.")
translated_ddl_test = translated.replace(
"CREATE TABLE ", f"CREATE TABLE test{row_idx}_{db_name}."
Expand Down Expand Up @@ -484,18 +499,22 @@ def ddl_to_mysql(ddl, db_type, db_name, row_idx):
Returns translated ddl and translated test ddl for testing.
"""
translated = ddl_to_dialect(ddl, db_type, "mysql")
translated = translated.replace(")\nCREATE", ");\nCREATE")
translated += ";"
translated = re.sub(r"VARCHAR(?!\()", "VARCHAR(255)", translated)
reserved_keywords = ["long"]
# if any of reserved keywords in the ddl, enclose them with backticks
for keyword in reserved_keywords:
translated = re.sub(
rf"\b{keyword}\b", f"`{keyword}`", translated, flags=re.IGNORECASE
)
# remove schema names
translated = ddl_remove_schema(translated)

# if any of reserved keywords in the non-comments section of ddl, enclose them with backticks
reserved_keywords = ["long"]
segments = re.split(r'(/\*.*?\*/)', translated)
for i, segment in enumerate(segments):
if not segment.startswith('/*'):
for keyword in reserved_keywords:
segment = re.sub(rf'(?<!")\b{keyword}\b(?!")', f"`{keyword}`", segment, flags=re.IGNORECASE)
segments[i] = segment

translated = ''.join(segments)

translated = translated.replace(")\nCREATE", ");\nCREATE")
translated = re.sub(r"VARCHAR(?!\()", "VARCHAR(255)", translated)
translated += ";"
translated_ddl = translated.replace("CREATE TABLE ", f"CREATE TABLE {db_name}.")
translated_ddl_test = translated.replace(
"CREATE TABLE ", f"CREATE TABLE test{row_idx}_{db_name}."
Expand Down Expand Up @@ -692,6 +711,17 @@ def ddl_to_sqlite(ddl, db_type, db_name, row_idx):
"""
translated = ddl_to_dialect(ddl, db_type, "sqlite")
translated = ddl_remove_schema(translated)

# if any of reserved keywords in the non-comments section of ddl, enclose them with backticks
reserved_keywords = ["transaction", "order"]
segments = re.split(r'(/\*.*?\*/)', translated)
for i, segment in enumerate(segments):
if not segment.startswith('/*'):
for keyword in reserved_keywords:
segment = re.sub(rf'(?<!")\b{keyword}\b(?!")', f"`{keyword}`", segment, flags=re.IGNORECASE)
segments[i] = segment
translated = ''.join(segments)

translated = translated.replace(")\nCREATE", ");\nCREATE")
translated = re.sub(r"SERIAL", "INTEGER PRIMARY KEY", translated)
translated += ";"
Expand Down

0 comments on commit ab26dfe

Please sign in to comment.