diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index c5931979..bf165461 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -931,7 +931,7 @@ def name(self): def compile(self, sql_ast): return self.dialect.compile(Compiler(self), sql_ast) - def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): + def query(self, sql_ast: Union[Expr, Generator], res_type: type = None, log_message: Optional[str] = None): """Query the given SQL code/AST, and attempt to convert the result to type 'res_type' If given a generator, it will execute all the yielded sql queries with the same thread and cursor. @@ -956,7 +956,10 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): if sql_code is SKIP: return SKIP - logger.debug("Running SQL (%s):\n%s", self.name, sql_code) + if log_message: + logger.debug("Running SQL (%s): %s \n%s", self.name, log_message, sql_code) + else: + logger.debug("Running SQL (%s):\n%s", self.name, sql_code) if self._interactive and isinstance(sql_ast, Select): explained_sql = self.compile(Explain(sql_ast)) @@ -1022,7 +1025,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: Note: This method exists instead of select_table_schema(), just because not all databases support accessing the schema using a SQL query. """ - rows = self.query(self.select_table_schema(path), list) + rows = self.query(self.select_table_schema(path), list, log_message=path) if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") @@ -1044,7 +1047,7 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]: """Query the table for its unique columns for table in 'path', and return {column}""" if not self.SUPPORTS_UNIQUE_CONSTAINT: raise NotImplementedError("This database doesn't support 'unique' constraints") - res = self.query(self.select_table_unique_columns(path), List[str]) + res = self.query(self.select_table_unique_columns(path), List[str], log_message=path) return list(res) def _process_table_schema( @@ -1086,7 +1089,9 @@ def _refine_coltypes( fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns] samples_by_row = self.query( - table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list + table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), + list, + log_message=table_path, ) if not samples_by_row: raise ValueError(f"Table {table_path} is empty.") diff --git a/data_diff/dbt.py b/data_diff/dbt.py index bf36c4fc..ef780429 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -8,6 +8,7 @@ import pydantic import rich from rich.prompt import Prompt +from concurrent.futures import ThreadPoolExecutor, as_completed from data_diff.errors import ( DataDiffCustomSchemaNoConfigError, @@ -80,7 +81,6 @@ def dbt_diff( production_schema_flag: Optional[str] = None, ) -> None: print_version_info() - diff_threads = [] set_entrypoint_name(os.getenv("DATAFOLD_TRIGGERED_BY", "CLI-dbt")) dbt_parser = DbtParser(profiles_dir_override, project_dir_override, state) models = dbt_parser.get_models(dbt_selection) @@ -112,7 +112,11 @@ def dbt_diff( else: dbt_parser.set_connection() - with log_status_handler.status if log_status_handler else nullcontext(): + futures = {} + + with log_status_handler.status if log_status_handler else nullcontext(), ThreadPoolExecutor( + max_workers=dbt_parser.threads + ) as executor: for model in models: if log_status_handler: log_status_handler.set_prefix(f"Diffing {model.alias} \n") @@ -140,12 +144,12 @@ def dbt_diff( if diff_vars.primary_keys: if is_cloud: - diff_thread = run_as_daemon( + future = executor.submit( _cloud_diff, diff_vars, config.datasource_id, api, org_meta, log_status_handler ) - diff_threads.append(diff_thread) else: - _local_diff(diff_vars, json_output) + future = executor.submit(_local_diff, diff_vars, json_output, log_status_handler) + futures[future] = model else: if json_output: print( @@ -165,10 +169,12 @@ def dbt_diff( + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n" ) - # wait for all threads - if diff_threads: - for thread in diff_threads: - thread.join() + for future in as_completed(futures): + model = futures[future] + try: + future.result() # if error occurred, it will be raised here + except Exception as e: + logger.error(f"An error occurred during the execution of a diff task: {model.unique_id} - {e}") _extension_notification() @@ -265,15 +271,17 @@ def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str, return prod_database, prod_schema, prod_alias -def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None: +def _local_diff( + diff_vars: TDiffVars, json_output: bool = False, log_status_handler: Optional[LogStatusHandler] = None +) -> None: + if log_status_handler: + log_status_handler.diff_started(diff_vars.dev_path[-1]) dev_qualified_str = ".".join(diff_vars.dev_path) prod_qualified_str = ".".join(diff_vars.prod_path) diff_output_str = _diff_output_base(dev_qualified_str, prod_qualified_str) - table1 = connect_to_table( - diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads - ) - table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads) + table1 = connect_to_table(diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys)) + table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys)) try: table1_columns = table1.get_schema() @@ -373,6 +381,9 @@ def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None: diff_output_str += no_differences_template() rich.print(diff_output_str) + if log_status_handler: + log_status_handler.diff_finished(diff_vars.dev_path[-1]) + def _initialize_api() -> Optional[DatafoldAPI]: datafold_host = os.environ.get("DATAFOLD_HOST") @@ -406,7 +417,7 @@ def _cloud_diff( log_status_handler: Optional[LogStatusHandler] = None, ) -> None: if log_status_handler: - log_status_handler.cloud_diff_started(diff_vars.dev_path[-1]) + log_status_handler.diff_started(diff_vars.dev_path[-1]) diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path)) payload = TCloudApiDataDiff( data_source1_id=datasource_id, @@ -476,7 +487,7 @@ def _cloud_diff( rich.print(diff_output_str) if log_status_handler: - log_status_handler.cloud_diff_finished(diff_vars.dev_path[-1]) + log_status_handler.diff_finished(diff_vars.dev_path[-1]) except BaseException as ex: # Catch KeyboardInterrupt too error = ex finally: diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py index 4b6124d5..0d864a57 100644 --- a/data_diff/dbt_parser.py +++ b/data_diff/dbt_parser.py @@ -446,17 +446,17 @@ def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str from_meta = [name for name, params in node.columns.items() if pk_tag in params.meta] or None if from_meta: - logger.debug("Found PKs via META: " + str(from_meta)) + logger.debug(f"Found PKs via META [{node.name}]: " + str(from_meta)) return from_meta from_tags = [name for name, params in node.columns.items() if pk_tag in params.tags] or None if from_tags: - logger.debug("Found PKs via Tags: " + str(from_tags)) + logger.debug(f"Found PKs via Tags [{node.name}]: " + str(from_tags)) return from_tags if node.unique_id in unique_columns: from_uniq = unique_columns.get(node.unique_id) if from_uniq is not None: - logger.debug("Found PKs via Uniqueness tests: " + str(from_uniq)) + logger.debug(f"Found PKs via Uniqueness tests [{node.name}]: {str(from_uniq)}") return list(from_uniq) except (KeyError, IndexError, TypeError) as e: diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 6fadc5d8..8e7fcf30 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -162,7 +162,7 @@ def _diff_tables_root(self, table1: TableSegment, table2: TableSegment, info_tre yield from self._diff_segments(None, table1, table2, info_tree, None) else: yield from self._bisect_and_diff_tables(table1, table2, info_tree) - logger.info("Diffing complete") + logger.info(f"Diffing complete: {table1.table_path} <> {table2.table_path}") if self.materialize_to_table: logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table)) @@ -193,8 +193,8 @@ def _diff_segments( partial(self._collect_stats, 1, table1, info_tree), partial(self._collect_stats, 2, table2, info_tree), partial(self._test_null_keys, table1, table2), - partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), - partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols), + partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols, table1, table2), + partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols, table1, table2), partial( self._materialize_diff, db, @@ -205,8 +205,8 @@ def _diff_segments( else None, ): assert len(a_cols) == len(b_cols) - logger.debug("Querying for different rows") - diff = db.query(diff_rows, list) + logger.debug(f"Querying for different rows: {table1.table_path}") + diff = db.query(diff_rows, list, log_message=table1.table_path) info_tree.info.set_diff(diff, schema=tuple(diff_rows.schema.items())) for is_xa, is_xb, *x in diff: if is_xa and is_xb: @@ -227,7 +227,7 @@ def _diff_segments( yield "+", tuple(b_row) def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): - logger.debug("Testing for duplicate keys") + logger.debug(f"Testing for duplicate keys: {table1.table_path} <> {table2.table_path}") # Test duplicate keys for ts in [table1, table2]: @@ -240,16 +240,16 @@ def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): unvalidated = list(set(key_columns) - set(unique)) if unvalidated: - logger.info(f"Validating that the are no duplicate keys in columns: {unvalidated}") + logger.info(f"Validating that the are no duplicate keys in columns: {unvalidated} for {ts.table_path}") # Validate that there are no duplicate keys self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated] q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True)) - total, total_distinct = ts.database.query(q, tuple) + total, total_distinct = ts.database.query(q, tuple, log_message=ts.table_path) if total != total_distinct: raise ValueError("Duplicate primary keys") def _test_null_keys(self, table1, table2): - logger.debug("Testing for null keys") + logger.debug(f"Testing for null keys: {table1.table_path} <> {table2.table_path}") # Test null keys for ts in [table1, table2]: @@ -257,7 +257,7 @@ def _test_null_keys(self, table1, table2): key_columns = ts.key_columns q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns)) - nulls = ts.database.query(q, list) + nulls = ts.database.query(q, list, log_message=ts.table_path) if nulls: if self.skip_null_keys: logger.warning( @@ -267,7 +267,7 @@ def _test_null_keys(self, table1, table2): raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}") def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree): - logger.debug(f"Collecting stats for table #{i}") + logger.debug(f"Collecting stats for table #{i}: {table_seg.table_path}") db = table_seg.database # Metrics @@ -288,7 +288,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree): ) col_exprs["count"] = Count() - res = db.query(table_seg.make_select().select(**col_exprs), tuple) + res = db.query(table_seg.make_select().select(**col_exprs), tuple, log_message=table_seg.table_path) for col_name, value in safezip(col_exprs, res): if value is not None: @@ -303,7 +303,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree): else: self.stats[stat_name] = value - logger.debug("Done collecting stats for table #%s", i) + logger.debug("Done collecting stats for table #%s: %s", i, table_seg.table_path) def _create_outer_join(self, table1, table2): db = table1.database @@ -334,23 +334,46 @@ def _create_outer_join(self, table1, table2): diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols)) return diff_rows, a_cols, b_cols, is_diff_cols, all_rows - def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): - logger.debug("Counting differences per column") - is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple) + def _count_diff_per_column( + self, + db, + diff_rows, + cols, + is_diff_cols, + table1: Optional[TableSegment] = None, + table2: Optional[TableSegment] = None, + ): + logger.info(type(table1)) + logger.debug(f"Counting differences per column: {table1.table_path} <> {table2.table_path}") + is_diff_cols_counts = db.query( + diff_rows.select(sum_(this[c]) for c in is_diff_cols), + tuple, + log_message=f"{table1.table_path} <> {table2.table_path}", + ) diff_counts = {} for name, count in safezip(cols, is_diff_cols_counts): diff_counts[name] = diff_counts.get(name, 0) + (count or 0) self.stats["diff_counts"] = diff_counts - def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): + def _sample_and_count_exclusive( + self, + db, + diff_rows, + a_cols, + b_cols, + table1: Optional[TableSegment] = None, + table2: Optional[TableSegment] = None, + ): if isinstance(db, (Oracle, MsSQL)): exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1)) else: exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) if not self.sample_exclusive_rows: - logger.debug("Counting exclusive rows") - self.stats["exclusive_count"] = db.query(exclusive_rows_query.count(), int) + logger.debug(f"Counting exclusive rows: {table1.table_path} <> {table2.table_path}") + self.stats["exclusive_count"] = db.query( + exclusive_rows_query.count(), int, log_message=f"{table1.table_path} <> {table2.table_path}" + ) return logger.info("Counting and sampling exclusive rows") diff --git a/data_diff/utils.py b/data_diff/utils.py index ee4a0f17..b9045cc1 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -485,31 +485,31 @@ def __init__(self): super().__init__() self.status = Status("") self.prefix = "" - self.cloud_diff_status = {} + self.diff_status = {} def emit(self, record): log_entry = self.format(record) - if self.cloud_diff_status: - self._update_cloud_status(log_entry) + if self.diff_status: + self._update_diff_status(log_entry) else: self.status.update(self.prefix + log_entry) def set_prefix(self, prefix_string): self.prefix = prefix_string - def cloud_diff_started(self, model_name): - self.cloud_diff_status[model_name] = "[yellow]In Progress[/]" - self._update_cloud_status() + def diff_started(self, model_name): + self.diff_status[model_name] = "[yellow]In Progress[/]" + self._update_diff_status() - def cloud_diff_finished(self, model_name): - self.cloud_diff_status[model_name] = "[green]Finished [/]" - self._update_cloud_status() + def diff_finished(self, model_name): + self.diff_status[model_name] = "[green]Finished [/]" + self._update_diff_status() - def _update_cloud_status(self, log=None): - cloud_status_string = "\n" - for model_name, status in self.cloud_diff_status.items(): - cloud_status_string += f"{status} {model_name}\n" - self.status.update(f"{cloud_status_string}{log or ''}") + def _update_diff_status(self, log=None): + status_string = "\n" + for model_name, status in self.diff_status.items(): + status_string += f"{status} {model_name}\n" + self.status.update(f"{status_string}{log or ''}") class UnknownMeta(type): diff --git a/tests/test_dbt.py b/tests/test_dbt.py index c281b6fb..31af99eb 100644 --- a/tests/test_dbt.py +++ b/tests/test_dbt.py @@ -93,8 +93,8 @@ def test_local_diff(self, mock_diff_tables): ) self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2) self.assertEqual(mock_connect.call_count, 2) - mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys), threads) - mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys), threads) + mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys)) + mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys)) mock_diff.get_stats_string.assert_called_once() @patch("data_diff.dbt.diff_tables") @@ -180,8 +180,8 @@ def test_local_diff_no_diffs(self, mock_diff_tables): ) self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2) self.assertEqual(mock_connect.call_count, 2) - mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys), None) - mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys), None) + mock_connect.assert_any_call(connection, ".".join(dev_qualified_list), tuple(expected_primary_keys)) + mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys)) mock_diff.get_stats_string.assert_not_called() @patch("data_diff.dbt.rich.print") @@ -248,6 +248,7 @@ def test_diff_is_cloud( where = "a_string" config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema", datasource_id=1) mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_model = Mock() mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake") mock_initialize_api.return_value = mock_api @@ -386,6 +387,7 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m threads = None where = "a_string" mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] @@ -407,7 +409,7 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m mock_dbt_parser_inst.get_models.assert_called_once() mock_dbt_parser_inst.set_connection.assert_called_once() mock_cloud_diff.assert_not_called() - mock_local_diff.assert_called_once_with(diff_vars, False) + mock_local_diff.assert_called_once_with(diff_vars, False, None) mock_print.assert_not_called() @patch("data_diff.dbt._get_diff_vars") @@ -423,6 +425,7 @@ def test_diff_state_model_dne( threads = None where = "a_string" mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] @@ -460,6 +463,7 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m threads = None where = "a_string" mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] @@ -481,7 +485,7 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m mock_dbt_parser_inst.get_models.assert_called_once() mock_dbt_parser_inst.set_connection.assert_called_once() mock_cloud_diff.assert_not_called() - mock_local_diff.assert_called_once_with(diff_vars, False) + mock_local_diff.assert_called_once_with(diff_vars, False, None) mock_print.assert_not_called() @patch("data_diff.dbt._get_diff_vars") @@ -497,6 +501,7 @@ def test_diff_only_prod_schema( threads = None where = "a_string" mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] @@ -518,7 +523,7 @@ def test_diff_only_prod_schema( mock_dbt_parser_inst.get_models.assert_called_once() mock_dbt_parser_inst.set_connection.assert_called_once() mock_cloud_diff.assert_not_called() - mock_local_diff.assert_called_once_with(diff_vars, False) + mock_local_diff.assert_called_once_with(diff_vars, False, None) mock_print.assert_not_called() @patch("data_diff.dbt._initialize_api") @@ -543,6 +548,7 @@ def test_diff_is_cloud_no_pks( mock_model = Mock() connection = {} threads = None + mock_dbt_parser_inst.threads = threads where = "a_string" config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema", datasource_id=1) mock_api = Mock() @@ -584,6 +590,7 @@ def test_diff_not_is_cloud_no_pks( threads = None where = "a_string" mock_dbt_parser_inst = Mock() + mock_dbt_parser_inst.threads = threads mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model]