diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_config.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_config.py index 6d6ba601556788..473cf9f907ba84 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_config.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_config.py @@ -1,14 +1,18 @@ from enum import Enum -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Union from airflow.configuration import conf +from pydantic import root_validator from pydantic.fields import Field import datahub.emitter.mce_builder as builder from datahub.configuration.common import AllowDenyPattern, ConfigModel if TYPE_CHECKING: - from datahub_airflow_plugin.hooks.datahub import DatahubGenericHook + from datahub_airflow_plugin.hooks.datahub import ( + DatahubCompositeHook, + DatahubGenericHook, + ) class DatajobUrl(Enum): @@ -27,6 +31,8 @@ class DatahubLineageConfig(ConfigModel): # DataHub hook connection ID. datahub_conn_id: str + _datahub_connection_ids: List[str] + # Cluster to associate with the pipelines and tasks. Defaults to "prod". cluster: str = builder.DEFAULT_FLOW_CLUSTER @@ -68,11 +74,25 @@ class DatahubLineageConfig(ConfigModel): disable_openlineage_plugin: bool = True - def make_emitter_hook(self) -> "DatahubGenericHook": + def make_emitter_hook(self) -> Union["DatahubGenericHook", "DatahubCompositeHook"]: # This is necessary to avoid issues with circular imports. - from datahub_airflow_plugin.hooks.datahub import DatahubGenericHook - - return DatahubGenericHook(self.datahub_conn_id) + from datahub_airflow_plugin.hooks.datahub import ( + DatahubCompositeHook, + DatahubGenericHook, + ) + + if len(self._datahub_connection_ids) == 1: + return DatahubGenericHook(self._datahub_connection_ids[0]) + else: + return DatahubCompositeHook(self._datahub_connection_ids) + + @root_validator(skip_on_failure=True) + def split_conn_ids(cls, values: Dict) -> Dict: + if not values.get("datahub_conn_id"): + raise ValueError("datahub_conn_id is required") + conn_ids = values.get("datahub_conn_id", "").split(",") + cls._datahub_connection_ids = [conn_id.strip() for conn_id in conn_ids] + return values def get_lineage_config() -> DatahubLineageConfig: diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin_v22.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin_v22.py index 99b0a40fd3c13e..b7e6e109e8567c 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin_v22.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin_v22.py @@ -29,6 +29,11 @@ TASK_ON_RETRY_CALLBACK = "on_retry_callback" +def load_config_v22(): + plugin_config = get_lineage_config() + return plugin_config + + def get_task_inlets_advanced(task: BaseOperator, context: Any) -> Iterable[Any]: # TODO: Fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae # in Airflow 2.4. 
@@ -99,9 +104,7 @@ def datahub_task_status_callback(context, status): task_inlets = get_task_inlets_advanced(task, context) task_outlets = get_task_outlets(task) - emitter = ( - DatahubGenericHook(config.datahub_conn_id).get_underlying_hook().make_emitter() - ) + emitter = config.make_emitter_hook().make_emitter() dataflow = AirflowGenerator.generate_dataflow( config=config, @@ -217,7 +220,7 @@ def datahub_pre_execution(context): def _wrap_pre_execution(pre_execution): def custom_pre_execution(context): - config = get_lineage_config() + config = load_config_v22() if config.enabled: context["_datahub_config"] = config datahub_pre_execution(context) @@ -231,7 +234,7 @@ def custom_pre_execution(context): def _wrap_on_failure_callback(on_failure_callback): def custom_on_failure_callback(context): - config = get_lineage_config() + config = load_config_v22() if config.enabled: context["_datahub_config"] = config try: @@ -251,7 +254,7 @@ def custom_on_failure_callback(context): def _wrap_on_success_callback(on_success_callback): def custom_on_success_callback(context): - config = get_lineage_config() + config = load_config_v22() if config.enabled: context["_datahub_config"] = config try: @@ -271,7 +274,8 @@ def custom_on_success_callback(context): def _wrap_on_retry_callback(on_retry_callback): def custom_on_retry_callback(context): - config = get_lineage_config() + config = load_config_v22() + if config.enabled: context["_datahub_config"] = config try: @@ -363,7 +367,7 @@ def _patch_datahub_policy(): _patch_policy(settings) - plugin_config = get_lineage_config() + plugin_config = load_config_v22() telemetry.telemetry_instance.ping( "airflow-plugin-init", { diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/hooks/datahub.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/hooks/datahub.py index 26c5026c075bd7..e2e4c9fef6f16e 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/hooks/datahub.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/hooks/datahub.py @@ -1,8 +1,9 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook +from datahub.emitter.composite_emitter import CompositeEmitter from datahub.emitter.generic_emitter import Emitter from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.metadata.com.linkedin.pegasus2avro.mxe import ( @@ -303,3 +304,45 @@ def emit( # Retained for backwards compatibility. emit_mces = emit + + +class DatahubCompositeHook(BaseHook): + """ + A hook that can emit metadata to multiple DataHub instances. + + :param datahub_conn_ids: References to the DataHub connections. 
+ :type datahub_conn_ids: List[str] + """ + + hooks: List[DatahubGenericHook] = [] + + def __init__(self, datahub_conn_ids: List[str]) -> None: + self.datahub_conn_ids = datahub_conn_ids + + def make_emitter(self) -> CompositeEmitter: + print(f"Create emitters for {self.datahub_conn_ids}") + return CompositeEmitter( + [ + self._get_underlying_hook(conn_id).make_emitter() + for conn_id in self.datahub_conn_ids + ] + ) + + def emit( + self, + items: Sequence[ + Union[ + MetadataChangeEvent, + MetadataChangeProposal, + MetadataChangeProposalWrapper, + ] + ], + ) -> None: + emitter = self.make_emitter() + + for item in items: + print(f"emitting item {item}") + emitter.emit(item) + + def _get_underlying_hook(self, conn_id: str) -> DatahubGenericHook: + return DatahubGenericHook(conn_id) diff --git a/metadata-ingestion-modules/airflow-plugin/tests/integration/test_plugin.py b/metadata-ingestion-modules/airflow-plugin/tests/integration/test_plugin.py index 2744c26021cde3..8d0bc9748ff5fd 100644 --- a/metadata-ingestion-modules/airflow-plugin/tests/integration/test_plugin.py +++ b/metadata-ingestion-modules/airflow-plugin/tests/integration/test_plugin.py @@ -50,6 +50,7 @@ class AirflowInstance: password: str metadata_file: pathlib.Path + metadata_file2: pathlib.Path @property def airflow_url(self) -> str: @@ -178,6 +179,7 @@ def _run_airflow( tmp_path: pathlib.Path, dags_folder: pathlib.Path, is_v1: bool, + multiple_connections: bool, ) -> Iterator[AirflowInstance]: airflow_home = tmp_path / "airflow_home" print(f"Using airflow home: {airflow_home}") @@ -189,7 +191,9 @@ def _run_airflow( print(f"Using airflow port: {airflow_port}") datahub_connection_name = "datahub_file_default" + datahub_connection_name_2 = "datahub_file_default_2" meta_file = tmp_path / "datahub_metadata.json" + meta_file2 = tmp_path / "datahub_metadata_2.json" environment = { **os.environ, @@ -204,7 +208,9 @@ def _run_airflow( "AIRFLOW__API__AUTH_BACKEND": "airflow.api.auth.backend.basic_auth", # Configure the datahub plugin and have it write the MCPs to a file. 
"AIRFLOW__CORE__LAZY_LOAD_PLUGINS": "False" if is_v1 else "True", - "AIRFLOW__DATAHUB__CONN_ID": datahub_connection_name, + "AIRFLOW__DATAHUB__CONN_ID": f"{datahub_connection_name}, {datahub_connection_name_2}" + if multiple_connections + else datahub_connection_name, "AIRFLOW__DATAHUB__DAG_FILTER_STR": f'{{ "deny": ["{DAG_TO_SKIP_INGESTION}"] }}', f"AIRFLOW_CONN_{datahub_connection_name.upper()}": Connection( conn_id="datahub_file_default", @@ -251,6 +257,13 @@ def _run_airflow( "SQLALCHEMY_SILENCE_UBER_WARNING": "1", } + if multiple_connections: + environment[f"AIRFLOW_CONN_{datahub_connection_name_2.upper()}"] = Connection( + conn_id="datahub_file_default2", + conn_type="datahub-file", + host=str(meta_file2), + ).get_uri() + if not HAS_AIRFLOW_STANDALONE_CMD: raise pytest.skip("Airflow standalone command is not available") @@ -315,6 +328,7 @@ def _run_airflow( username=airflow_username, password=airflow_password, metadata_file=meta_file, + metadata_file2=meta_file2, ) yield airflow_instance @@ -355,10 +369,11 @@ class DagTestCase: success: bool = True v2_only: bool = False + multiple_connections: bool = False test_cases = [ - DagTestCase("simple_dag"), + DagTestCase("simple_dag", multiple_connections=True), DagTestCase("basic_iolets"), DagTestCase("dag_to_skip", v2_only=True), DagTestCase("snowflake_operator", success=False, v2_only=True), @@ -441,7 +456,10 @@ def test_airflow_plugin( dag_id = test_case.dag_id with _run_airflow( - tmp_path, dags_folder=DAGS_FOLDER, is_v1=is_v1 + tmp_path, + dags_folder=DAGS_FOLDER, + is_v1=is_v1, + multiple_connections=test_case.multiple_connections, ) as airflow_instance: print(f"Running DAG {dag_id}...") _wait_for_dag_to_load(airflow_instance, dag_id) @@ -491,6 +509,21 @@ def test_airflow_plugin( ], ) + if test_case.multiple_connections: + _sanitize_output_file(airflow_instance.metadata_file2) + check_golden_file( + pytestconfig=pytestconfig, + output_path=airflow_instance.metadata_file2, + golden_path=golden_path, + ignore_paths=[ + # TODO: If we switched to Git urls, maybe we could get this to work consistently. + r"root\[\d+\]\['aspect'\]\['json'\]\['customProperties'\]\['datahub_sql_parser_error'\]", + r"root\[\d+\]\['aspect'\]\['json'\]\['customProperties'\]\['openlineage_.*'\]", + r"root\[\d+\]\['aspect'\]\['json'\]\['customProperties'\]\['log_url'\]", + r"root\[\d+\]\['aspect'\]\['json'\]\['externalUrl'\]", + ], + ) + def _sanitize_output_file(output_path: pathlib.Path) -> None: # Overwrite some custom properties in the output file to make it easier to compare. 
@@ -534,6 +567,7 @@ def _sanitize(obj: Any) -> None: tmp_path=pathlib.Path(tempfile.mkdtemp("airflow-plugin-test")), dags_folder=DAGS_FOLDER, is_v1=not HAS_AIRFLOW_LISTENER_API, + multiple_connections=False, ) as airflow_instance: # input("Press enter to exit...") breakpoint() diff --git a/metadata-ingestion/src/datahub/emitter/composite_emitter.py b/metadata-ingestion/src/datahub/emitter/composite_emitter.py new file mode 100644 index 00000000000000..fe35f74be11cb3 --- /dev/null +++ b/metadata-ingestion/src/datahub/emitter/composite_emitter.py @@ -0,0 +1,36 @@ +from typing import Callable, List, Optional, Union + +from datahub.emitter.generic_emitter import Emitter +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.metadata.com.linkedin.pegasus2avro.mxe import ( + MetadataChangeEvent, + MetadataChangeProposal, +) + + +# Experimental composite emitter that allows multiple emitters to be used in a single ingestion job +class CompositeEmitter(Emitter): + def __init__(self, emitters: List[Emitter]) -> None: + self.emitters = emitters + + def emit( + self, + item: Union[ + MetadataChangeEvent, + MetadataChangeProposal, + MetadataChangeProposalWrapper, + ], + callback: Optional[Callable[[Exception, str], None]] = None, + ) -> None: + callback_called = False + for emitter in self.emitters: + if not callback_called: + # We want to ensure that the callback is only called once and we tie it to the first emitter + emitter.emit(item, callback) + callback_called = True + else: + emitter.emit(item) + + def flush(self) -> None: + for emitter in self.emitters: + emitter.flush() diff --git a/metadata-ingestion/tests/unit/datahub/emitter/test_composite_emitter.py b/metadata-ingestion/tests/unit/datahub/emitter/test_composite_emitter.py new file mode 100644 index 00000000000000..fdab1224d77d4a --- /dev/null +++ b/metadata-ingestion/tests/unit/datahub/emitter/test_composite_emitter.py @@ -0,0 +1,42 @@ +from unittest.mock import MagicMock + +import pytest + +from datahub.emitter.composite_emitter import CompositeEmitter +from datahub.emitter.generic_emitter import Emitter +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.metadata.com.linkedin.pegasus2avro.dataset import DatasetProfile + + +@pytest.fixture +def mock_emitters(): + return [MagicMock(spec=Emitter), MagicMock(spec=Emitter)] + + +def test_composite_emitter_emit(mock_emitters): + composite_emitter = CompositeEmitter(mock_emitters) + item = MetadataChangeProposalWrapper( + entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)", + aspect=DatasetProfile( + rowCount=2000, + columnCount=15, + timestampMillis=1626995099686, + ), + ) + callback = MagicMock() + + composite_emitter.emit(item, callback) + + mock_emitters[0].emit.assert_called_once_with(item, callback) + mock_emitters[1].emit.assert_called_once_with(item) + assert mock_emitters[0].emit.call_count == 1 + assert mock_emitters[1].emit.call_count == 1 + + +def test_composite_emitter_flush(mock_emitters): + composite_emitter = CompositeEmitter(mock_emitters) + + composite_emitter.flush() + + for emitter in mock_emitters: + emitter.flush.assert_called_once()
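For context, a minimal sketch (not part of this diff) of how the new CompositeEmitter could be exercised outside of Airflow. The GMS URLs and the dataset URN are placeholder values, and the sketch assumes the standard DatahubRestEmitter and MetadataChangeProposalWrapper APIs from the datahub package.

# Illustrative sketch, not part of this change: fan one MCP out to two
# DataHub instances via CompositeEmitter. Server URLs and the URN below
# are placeholders.
from datahub.emitter.composite_emitter import CompositeEmitter
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.metadata.schema_classes import StatusClass

emitter = CompositeEmitter(
    [
        DatahubRestEmitter(gms_server="http://datahub-primary:8080"),
        DatahubRestEmitter(gms_server="http://datahub-secondary:8080"),
    ]
)

mcp = MetadataChangeProposalWrapper(
    entityUrn="urn:li:dataset:(urn:li:dataPlatform:mysql,User.UserAccount,PROD)",
    aspect=StatusClass(removed=False),
)

# Per CompositeEmitter.emit above, a callback (if supplied) is attached to
# the first emitter only; the remaining emitters receive the item without one.
emitter.emit(mcp)
emitter.flush()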
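And a sketch of the plugin-side configuration that this change enables, mirroring the integration test above. The connection names and file paths are hypothetical; the comma-separated AIRFLOW__DATAHUB__CONN_ID value is what the split_conn_ids validator parses into _datahub_connection_ids, which in turn makes make_emitter_hook() return a DatahubCompositeHook.

# Illustrative sketch, not part of this change: two DataHub connections for
# the Airflow plugin, following the integration test. Connection names and
# file paths are hypothetical.
import os

from airflow.models.connection import Connection

# Two Airflow connections, each pointing at its own DataHub "file" sink.
os.environ["AIRFLOW_CONN_DATAHUB_FILE_DEFAULT"] = Connection(
    conn_id="datahub_file_default",
    conn_type="datahub-file",
    host="/tmp/datahub_metadata.json",
).get_uri()
os.environ["AIRFLOW_CONN_DATAHUB_FILE_DEFAULT_2"] = Connection(
    conn_id="datahub_file_default_2",
    conn_type="datahub-file",
    host="/tmp/datahub_metadata_2.json",
).get_uri()

# A comma-separated list of conn IDs switches the plugin from a single
# DatahubGenericHook to the new DatahubCompositeHook.
os.environ["AIRFLOW__DATAHUB__CONN_ID"] = (
    "datahub_file_default, datahub_file_default_2"
)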