Make dbt data diffs concurrent (#776)
* v0 of concurrency

* concurrent logging

* remove todo

* remove todo

* better var name

* add node name to logger

* format string logs

* add optional logger param

* avoid extra threads

* use thread pools

* not multithreaded at the connection level anymore

* show errors as they happen

* show full stacktrace on error

* rearrange trace

* more logs for debugging

* update for threads mocking

* clear log params

* remove extra space

* remove long traceback

* Ensure log_message is optional

Co-authored-by: Dan Lawin <[email protected]>

* map threaded result to proper model id

* explicit type and optional

* rm submodules again

---------

Co-authored-by: Sung Won Chung <[email protected]>
Co-authored-by: Dan Lawin <[email protected]>
3 people authored Dec 5, 2023
1 parent 29b48b0 commit b3d4223
Showing 6 changed files with 110 additions and 64 deletions.
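The heart of the change is in data_diff/dbt.py below: per-model diffs move from ad-hoc daemon threads (run_as_daemon) to a ThreadPoolExecutor, each future is keyed back to its model, and as_completed reports worker errors as they happen. A minimal sketch of that pattern, with a hypothetical diff_one standing in for _local_diff/_cloud_diff:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def diff_one(model: str) -> None:
    """Stand-in for _local_diff / _cloud_diff."""
    if model == "model_c":
        raise ValueError("simulated diff failure")

models = ["model_a", "model_b", "model_c"]
futures = {}

# The real code sizes the pool with dbt_parser.threads.
with ThreadPoolExecutor(max_workers=4) as executor:
    for model in models:
        future = executor.submit(diff_one, model)
        futures[future] = model  # map each future back to the model it diffs

    for future in as_completed(futures):  # yields futures as they finish, not in submit order
        model = futures[future]
        try:
            future.result()  # an exception raised in the worker thread re-raises here
        except Exception as e:
            print(f"An error occurred during the execution of a diff task: {model} - {e}")
```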
15 changes: 10 additions & 5 deletions data_diff/databases/base.py
@@ -931,7 +931,7 @@ def name(self):
def compile(self, sql_ast):
return self.dialect.compile(Compiler(self), sql_ast)

def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
def query(self, sql_ast: Union[Expr, Generator], res_type: type = None, log_message: Optional[str] = None):
"""Query the given SQL code/AST, and attempt to convert the result to type 'res_type'
If given a generator, it will execute all the yielded sql queries with the same thread and cursor.
@@ -956,7 +956,10 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
if sql_code is SKIP:
return SKIP

logger.debug("Running SQL (%s):\n%s", self.name, sql_code)
if log_message:
logger.debug("Running SQL (%s): %s \n%s", self.name, log_message, sql_code)
else:
logger.debug("Running SQL (%s):\n%s", self.name, sql_code)

if self._interactive and isinstance(sql_ast, Select):
explained_sql = self.compile(Explain(sql_ast))
@@ -1022,7 +1025,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
Note: This method exists instead of select_table_schema(), just because not all databases support
accessing the schema using a SQL query.
"""
rows = self.query(self.select_table_schema(path), list)
rows = self.query(self.select_table_schema(path), list, log_message=path)
if not rows:
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

@@ -1044,7 +1047,7 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]:
"""Query the table for its unique columns for table in 'path', and return {column}"""
if not self.SUPPORTS_UNIQUE_CONSTAINT:
raise NotImplementedError("This database doesn't support 'unique' constraints")
res = self.query(self.select_table_unique_columns(path), List[str])
res = self.query(self.select_table_unique_columns(path), List[str], log_message=path)
return list(res)

def _process_table_schema(
@@ -1086,7 +1089,9 @@ def _refine_coltypes(
fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]

samples_by_row = self.query(
table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list
table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size),
list,
log_message=table_path,
)
if not samples_by_row:
raise ValueError(f"Table {table_path} is empty.")
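With many diffs in flight, a query's thread no longer identifies its model, so Database.query() gains an optional log_message that callers fill with table context. A minimal sketch of the same pattern, assuming plain logging (run_query and the literals are illustrative, not the library's API):

```python
import logging
from typing import Optional

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("sketch")

def run_query(sql_code: str, log_message: Optional[str] = None) -> None:
    # Tag the debug line with caller-supplied context (e.g. a table path) so
    # interleaved logs from concurrent workers stay attributable.
    if log_message:
        logger.debug("Running SQL: %s\n%s", log_message, sql_code)
    else:
        logger.debug("Running SQL:\n%s", sql_code)

run_query("SELECT count(*) FROM my_schema.my_table", log_message="my_schema.my_table")
```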
43 changes: 27 additions & 16 deletions data_diff/dbt.py
@@ -8,6 +8,7 @@
import pydantic
import rich
from rich.prompt import Prompt
from concurrent.futures import ThreadPoolExecutor, as_completed

from data_diff.errors import (
DataDiffCustomSchemaNoConfigError,
@@ -80,7 +81,6 @@ def dbt_diff(
production_schema_flag: Optional[str] = None,
) -> None:
print_version_info()
diff_threads = []
set_entrypoint_name(os.getenv("DATAFOLD_TRIGGERED_BY", "CLI-dbt"))
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, state)
models = dbt_parser.get_models(dbt_selection)
@@ -112,7 +112,11 @@ def dbt_diff(
else:
dbt_parser.set_connection()

with log_status_handler.status if log_status_handler else nullcontext():
futures = {}

with log_status_handler.status if log_status_handler else nullcontext(), ThreadPoolExecutor(
max_workers=dbt_parser.threads
) as executor:
for model in models:
if log_status_handler:
log_status_handler.set_prefix(f"Diffing {model.alias} \n")
@@ -140,12 +144,12 @@

if diff_vars.primary_keys:
if is_cloud:
diff_thread = run_as_daemon(
future = executor.submit(
_cloud_diff, diff_vars, config.datasource_id, api, org_meta, log_status_handler
)
diff_threads.append(diff_thread)
else:
_local_diff(diff_vars, json_output)
future = executor.submit(_local_diff, diff_vars, json_output, log_status_handler)
futures[future] = model
else:
if json_output:
print(
@@ -165,10 +169,12 @@
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
)

# wait for all threads
if diff_threads:
for thread in diff_threads:
thread.join()
for future in as_completed(futures):
model = futures[future]
try:
future.result() # if error occurred, it will be raised here
except Exception as e:
logger.error(f"An error occurred during the execution of a diff task: {model.unique_id} - {e}")

_extension_notification()

@@ -265,15 +271,17 @@ def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str,
return prod_database, prod_schema, prod_alias


def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None:
def _local_diff(
diff_vars: TDiffVars, json_output: bool = False, log_status_handler: Optional[LogStatusHandler] = None
) -> None:
if log_status_handler:
log_status_handler.diff_started(diff_vars.dev_path[-1])
dev_qualified_str = ".".join(diff_vars.dev_path)
prod_qualified_str = ".".join(diff_vars.prod_path)
diff_output_str = _diff_output_base(dev_qualified_str, prod_qualified_str)

table1 = connect_to_table(
diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads
)
table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads)
table1 = connect_to_table(diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys))
table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys))

try:
table1_columns = table1.get_schema()
@@ -373,6 +381,9 @@ def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None:
diff_output_str += no_differences_template()
rich.print(diff_output_str)

if log_status_handler:
log_status_handler.diff_finished(diff_vars.dev_path[-1])


def _initialize_api() -> Optional[DatafoldAPI]:
datafold_host = os.environ.get("DATAFOLD_HOST")
@@ -406,7 +417,7 @@ def _cloud_diff(
log_status_handler: Optional[LogStatusHandler] = None,
) -> None:
if log_status_handler:
log_status_handler.cloud_diff_started(diff_vars.dev_path[-1])
log_status_handler.diff_started(diff_vars.dev_path[-1])
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
payload = TCloudApiDataDiff(
data_source1_id=datasource_id,
@@ -476,7 +487,7 @@ def _cloud_diff(
rich.print(diff_output_str)

if log_status_handler:
log_status_handler.cloud_diff_finished(diff_vars.dev_path[-1])
log_status_handler.diff_finished(diff_vars.dev_path[-1])
except BaseException as ex: # Catch KeyboardInterrupt too
error = ex
finally:
6 changes: 3 additions & 3 deletions data_diff/dbt_parser.py
@@ -446,17 +446,17 @@ def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str

from_meta = [name for name, params in node.columns.items() if pk_tag in params.meta] or None
if from_meta:
logger.debug("Found PKs via META: " + str(from_meta))
logger.debug(f"Found PKs via META [{node.name}]: " + str(from_meta))
return from_meta

from_tags = [name for name, params in node.columns.items() if pk_tag in params.tags] or None
if from_tags:
logger.debug("Found PKs via Tags: " + str(from_tags))
logger.debug(f"Found PKs via Tags [{node.name}]: " + str(from_tags))
return from_tags
if node.unique_id in unique_columns:
from_uniq = unique_columns.get(node.unique_id)
if from_uniq is not None:
logger.debug("Found PKs via Uniqueness tests: " + str(from_uniq))
logger.debug(f"Found PKs via Uniqueness tests [{node.name}]: {str(from_uniq)}")
return list(from_uniq)

except (KeyError, IndexError, TypeError) as e:
61 changes: 42 additions & 19 deletions data_diff/joindiff_tables.py
@@ -162,7 +162,7 @@ def _diff_tables_root(self, table1: TableSegment, table2: TableSegment, info_tre
yield from self._diff_segments(None, table1, table2, info_tree, None)
else:
yield from self._bisect_and_diff_tables(table1, table2, info_tree)
logger.info("Diffing complete")
logger.info(f"Diffing complete: {table1.table_path} <> {table2.table_path}")
if self.materialize_to_table:
logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table))

@@ -193,8 +193,8 @@ def _diff_segments(
partial(self._collect_stats, 1, table1, info_tree),
partial(self._collect_stats, 2, table2, info_tree),
partial(self._test_null_keys, table1, table2),
partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols),
partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols),
partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols, table1, table2),
partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols, table1, table2),
partial(
self._materialize_diff,
db,
@@ -205,8 +205,8 @@
else None,
):
assert len(a_cols) == len(b_cols)
logger.debug("Querying for different rows")
diff = db.query(diff_rows, list)
logger.debug(f"Querying for different rows: {table1.table_path}")
diff = db.query(diff_rows, list, log_message=table1.table_path)
info_tree.info.set_diff(diff, schema=tuple(diff_rows.schema.items()))
for is_xa, is_xb, *x in diff:
if is_xa and is_xb:
@@ -227,7 +227,7 @@
yield "+", tuple(b_row)

def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):
logger.debug("Testing for duplicate keys")
logger.debug(f"Testing for duplicate keys: {table1.table_path} <> {table2.table_path}")

# Test duplicate keys
for ts in [table1, table2]:
@@ -240,24 +240,24 @@ def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):

unvalidated = list(set(key_columns) - set(unique))
if unvalidated:
logger.info(f"Validating that the are no duplicate keys in columns: {unvalidated}")
logger.info(f"Validating that the are no duplicate keys in columns: {unvalidated} for {ts.table_path}")
# Validate that there are no duplicate keys
self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated]
q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True))
total, total_distinct = ts.database.query(q, tuple)
total, total_distinct = ts.database.query(q, tuple, log_message=ts.table_path)
if total != total_distinct:
raise ValueError("Duplicate primary keys")

def _test_null_keys(self, table1, table2):
logger.debug("Testing for null keys")
logger.debug(f"Testing for null keys: {table1.table_path} <> {table2.table_path}")

# Test null keys
for ts in [table1, table2]:
t = ts.make_select()
key_columns = ts.key_columns

q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns))
nulls = ts.database.query(q, list)
nulls = ts.database.query(q, list, log_message=ts.table_path)
if nulls:
if self.skip_null_keys:
logger.warning(
@@ -267,7 +267,7 @@ def _test_null_keys(self, table1, table2):
raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}")

def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
logger.debug(f"Collecting stats for table #{i}")
logger.debug(f"Collecting stats for table #{i}: {table_seg.table_path}")
db = table_seg.database

# Metrics
@@ -288,7 +288,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
)
col_exprs["count"] = Count()

res = db.query(table_seg.make_select().select(**col_exprs), tuple)
res = db.query(table_seg.make_select().select(**col_exprs), tuple, log_message=table_seg.table_path)

for col_name, value in safezip(col_exprs, res):
if value is not None:
@@ -303,7 +303,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
else:
self.stats[stat_name] = value

logger.debug("Done collecting stats for table #%s", i)
logger.debug("Done collecting stats for table #%s: %s", i, table_seg.table_path)

def _create_outer_join(self, table1, table2):
db = table1.database
@@ -334,23 +334,46 @@ def _create_outer_join(self, table1, table2):
diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols))
return diff_rows, a_cols, b_cols, is_diff_cols, all_rows

def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
logger.debug("Counting differences per column")
is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple)
def _count_diff_per_column(
self,
db,
diff_rows,
cols,
is_diff_cols,
table1: Optional[TableSegment] = None,
table2: Optional[TableSegment] = None,
):
logger.info(type(table1))
logger.debug(f"Counting differences per column: {table1.table_path} <> {table2.table_path}")
is_diff_cols_counts = db.query(
diff_rows.select(sum_(this[c]) for c in is_diff_cols),
tuple,
log_message=f"{table1.table_path} <> {table2.table_path}",
)
diff_counts = {}
for name, count in safezip(cols, is_diff_cols_counts):
diff_counts[name] = diff_counts.get(name, 0) + (count or 0)
self.stats["diff_counts"] = diff_counts

def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
def _sample_and_count_exclusive(
self,
db,
diff_rows,
a_cols,
b_cols,
table1: Optional[TableSegment] = None,
table2: Optional[TableSegment] = None,
):
if isinstance(db, (Oracle, MsSQL)):
exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1))
else:
exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b)

if not self.sample_exclusive_rows:
logger.debug("Counting exclusive rows")
self.stats["exclusive_count"] = db.query(exclusive_rows_query.count(), int)
logger.debug(f"Counting exclusive rows: {table1.table_path} <> {table2.table_path}")
self.stats["exclusive_count"] = db.query(
exclusive_rows_query.count(), int, log_message=f"{table1.table_path} <> {table2.table_path}"
)
return

logger.info("Counting and sampling exclusive rows")
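The stats helpers above are scheduled as functools.partial objects inside _diff_segments, so the new table1/table2 context rides along as two more bound arguments. A small sketch of that mechanism, with hypothetical names:

```python
from functools import partial

def count_diffs(db_name: str, table_a: str, table_b: str) -> None:
    print(f"Counting differences per column: {table_a} <> {table_b} on {db_name}")

# partial() binds arguments now and defers the call, which is what lets
# _diff_segments hand a list of ready-to-run tasks to its runner.
task = partial(count_diffs, "duckdb", "prod.orders", "dev.orders")
task()  # -> Counting differences per column: prod.orders <> dev.orders on duckdb
```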
28 changes: 14 additions & 14 deletions data_diff/utils.py
@@ -485,31 +485,31 @@ def __init__(self):
super().__init__()
self.status = Status("")
self.prefix = ""
self.cloud_diff_status = {}
self.diff_status = {}

def emit(self, record):
log_entry = self.format(record)
if self.cloud_diff_status:
self._update_cloud_status(log_entry)
if self.diff_status:
self._update_diff_status(log_entry)
else:
self.status.update(self.prefix + log_entry)

def set_prefix(self, prefix_string):
self.prefix = prefix_string

def cloud_diff_started(self, model_name):
self.cloud_diff_status[model_name] = "[yellow]In Progress[/]"
self._update_cloud_status()
def diff_started(self, model_name):
self.diff_status[model_name] = "[yellow]In Progress[/]"
self._update_diff_status()

def cloud_diff_finished(self, model_name):
self.cloud_diff_status[model_name] = "[green]Finished [/]"
self._update_cloud_status()
def diff_finished(self, model_name):
self.diff_status[model_name] = "[green]Finished [/]"
self._update_diff_status()

def _update_cloud_status(self, log=None):
cloud_status_string = "\n"
for model_name, status in self.cloud_diff_status.items():
cloud_status_string += f"{status} {model_name}\n"
self.status.update(f"{cloud_status_string}{log or ''}")
def _update_diff_status(self, log=None):
status_string = "\n"
for model_name, status in self.diff_status.items():
status_string += f"{status} {model_name}\n"
self.status.update(f"{status_string}{log or ''}")


class UnknownMeta(type):
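With local diffs now reporting progress too, the handler drops its cloud_ prefix: it is a logging.Handler that repaints a rich Status spinner with one line per model plus the latest log record. A self-contained sketch of the idea, assuming rich is installed (DiffStatusHandler is an illustrative name):

```python
import logging
import time

from rich.status import Status

class DiffStatusHandler(logging.Handler):
    """Sketch: repaint a live spinner with per-model diff states."""

    def __init__(self) -> None:
        super().__init__()
        self.status = Status("")
        self.diff_status: dict = {}  # model name -> rich-markup state string

    def emit(self, record: logging.LogRecord) -> None:
        self._repaint(self.format(record))

    def diff_started(self, model_name: str) -> None:
        self.diff_status[model_name] = "[yellow]In Progress[/]"
        self._repaint()

    def diff_finished(self, model_name: str) -> None:
        self.diff_status[model_name] = "[green]Finished [/]"
        self._repaint()

    def _repaint(self, log: str = "") -> None:
        lines = "".join(f"{state} {name}\n" for name, state in self.diff_status.items())
        self.status.update(f"\n{lines}{log}")

handler = DiffStatusHandler()
with handler.status:  # Status is a context manager; enters/exits the live display
    handler.diff_started("orders")
    time.sleep(0.3)
    handler.diff_finished("orders")
```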