Skip to content

Commit 0e8fc51

Browse files
authored
feat(cli): cache sql parsing intermediates (datahub-project#10399)
1 parent 1dae37a commit 0e8fc51

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -406,10 +406,11 @@ def _schema_aware_fuzzy_column_resolve(
406406
return default_col_name
407407

408408
# Optimize the statement + qualify column references.
409-
logger.debug(
410-
"Prior to column qualification sql %s",
411-
statement.sql(pretty=True, dialect=dialect),
412-
)
409+
if logger.isEnabledFor(logging.DEBUG):
410+
logger.debug(
411+
"Prior to column qualification sql %s",
412+
statement.sql(pretty=True, dialect=dialect),
413+
)
413414
try:
414415
# Second time running qualify, this time with:
415416
# - the select instead of the full outer statement
@@ -434,7 +435,8 @@ def _schema_aware_fuzzy_column_resolve(
434435
raise SqlUnderstandingError(
435436
f"sqlglot failed to map columns to their source tables; likely missing/outdated table schema info: {e}"
436437
) from e
437-
logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))
438+
if logger.isEnabledFor(logging.DEBUG):
439+
logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))
438440

439441
# Handle the create DDL case.
440442
if is_create_ddl:
@@ -805,7 +807,7 @@ def _sqlglot_lineage_inner(
805807
logger.debug("Parsing lineage from sql statement: %s", sql)
806808
statement = parse_statement(sql, dialect=dialect)
807809

808-
original_statement = statement.copy()
810+
original_statement, statement = statement, statement.copy()
809811
# logger.debug(
810812
# "Formatted sql statement: %s",
811813
# original_statement.sql(pretty=True, dialect=dialect),

metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import hashlib
23
import logging
34
from typing import Dict, Iterable, Optional, Tuple, Union
@@ -7,6 +8,7 @@
78

89
logger = logging.getLogger(__name__)
910
DialectOrStr = Union[sqlglot.Dialect, str]
11+
SQL_PARSE_CACHE_SIZE = 1000
1012

1113

1214
def _get_dialect_str(platform: str) -> str:
@@ -55,7 +57,8 @@ def is_dialect_instance(
5557
return False
5658

5759

58-
def parse_statement(
60+
@functools.lru_cache(maxsize=SQL_PARSE_CACHE_SIZE)
61+
def _parse_statement(
5962
sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
6063
) -> sqlglot.Expression:
6164
statement: sqlglot.Expression = sqlglot.maybe_parse(
@@ -64,6 +67,16 @@ def parse_statement(
6467
return statement
6568

6669

70+
def parse_statement(
71+
sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
72+
) -> sqlglot.Expression:
73+
# Parsing is significantly more expensive than copying the expression.
74+
# Because the expressions are mutable, we don't want to allow the caller
75+
# to modify the parsed expression that sits in the cache. We keep
76+
# the cached versions pristine by returning a copy on each call.
77+
return _parse_statement(sql, dialect).copy()
78+
79+
6780
def parse_statements_and_pick(sql: str, platform: DialectOrStr) -> sqlglot.Expression:
6881
dialect = get_dialect(platform)
6982
statements = [
@@ -277,4 +290,5 @@ def replace_cte_refs(node: sqlglot.exp.Expression) -> sqlglot.exp.Expression:
277290
else:
278291
return node
279292

293+
statement = statement.copy()
280294
return statement.transform(replace_cte_refs, copy=False)

0 commit comments

Comments
 (0)