fix(ingestion/airflow-plugin): incorporated review comments and added the test
dushayntAW committed Oct 8, 2024
1 parent aba3d0f commit c87c070
Showing 5 changed files with 502 additions and 110 deletions.
@@ -1,9 +1,11 @@
from enum import Enum
from typing import TYPE_CHECKING, Optional

from pydantic.fields import Field

import datahub.emitter.mce_builder as builder
from airflow.configuration import conf
from datahub.configuration.common import ConfigModel
from datahub.configuration.common import ConfigModel, AllowDenyPattern

if TYPE_CHECKING:
from datahub_airflow_plugin.hooks.datahub import DatahubGenericHook
@@ -56,7 +58,10 @@ class DatahubLineageConfig(ConfigModel):

datajob_url_link: DatajobUrl = DatajobUrl.TASKINSTANCE

dag_allow_deny_pattern_str: str = '{"allow": [".*"]}'
dag_allow_deny_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="regex patterns for DAGs to ingest",
)

def make_emitter_hook(self) -> "DatahubGenericHook":
# This is necessary to avoid issues with circular imports.
@@ -89,8 +94,8 @@ def get_lineage_config() -> DatahubLineageConfig:
datajob_url_link = conf.get(
"datahub", "datajob_url_link", fallback=DatajobUrl.TASKINSTANCE.value
)
dag_allow_deny_pattern_str = conf.get(
"datahub", "dag_allow_deny_pattern_str", fallback='{"allow": [".*"]}'
dag_allow_deny_pattern = AllowDenyPattern.parse_raw(
conf.get("datahub", "dag_allow_deny_pattern", fallback='{"allow": [".*"]}')
)

return DatahubLineageConfig(
@@ -107,5 +112,5 @@ def get_lineage_config() -> DatahubLineageConfig:
debug_emitter=debug_emitter,
disable_openlineage_plugin=disable_openlineage_plugin,
datajob_url_link=datajob_url_link,
dag_allow_deny_pattern_str=dag_allow_deny_pattern_str,
dag_allow_deny_pattern=dag_allow_deny_pattern,
)
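
The new dag_allow_deny_pattern setting is read from the [datahub] section of the Airflow configuration (the section, key, and fallback come from the conf.get call above). As a minimal sketch, it can be exercised directly with AllowDenyPattern; the regex values below are illustrative assumptions, not part of this commit:

from datahub.configuration.common import AllowDenyPattern

# Same JSON shape as the dag_allow_deny_pattern value read above; the concrete
# regexes here are assumptions chosen for illustration.
pattern = AllowDenyPattern.parse_raw('{"allow": [".*"], "deny": ["^tmp_.*"]}')

pattern.allowed("etl_daily_load")   # True: no deny regex matches, the allow regex does
pattern.allowed("tmp_scratch_dag")  # False: matches the deny regex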
@@ -12,7 +12,6 @@
from airflow.models.serialized_dag import SerializedDagModel
from datahub.api.entities.datajob import DataJob
from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult
from datahub.configuration.common import AllowDenyPattern
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.ingestion.graph.client import DataHubGraph
@@ -384,95 +383,100 @@ def on_task_instance_running(
return

logger.debug(
f"DataHub listener got notification about task instance start for {task_instance.task_id}"
f"DataHub listener got notification about task instance start for {task_instance.task_id} of dag {task_instance.dag_run.dag_id}"
)

task_instance = _render_templates(task_instance)
if self.config.dag_allow_deny_pattern.allowed(task_instance.dag_run.dag_id):
task_instance = _render_templates(task_instance)

# The type ignore is to placate mypy on Airflow 2.1.x.
dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined]
task = task_instance.task
assert task is not None
dag: "DAG" = task.dag # type: ignore[assignment]
# The type ignore is to placate mypy on Airflow 2.1.x.
dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined]
task = task_instance.task
assert task is not None
dag: "DAG" = task.dag # type: ignore[assignment]

self._task_holder.set_task(task_instance)
self._task_holder.set_task(task_instance)

# Handle async operators in Airflow 2.3 by skipping deferred state.
# Inspired by https://github.com/OpenLineage/OpenLineage/pull/1601
if task_instance.next_method is not None: # type: ignore[attr-defined]
return
# Handle async operators in Airflow 2.3 by skipping deferred state.
# Inspired by https://github.com/OpenLineage/OpenLineage/pull/1601
if task_instance.next_method is not None: # type: ignore[attr-defined]
return

# If we don't have the DAG listener API, we just pretend that
# the start of the task is the start of the DAG.
# This generates duplicate events, but it's better than not
# generating anything.
if not HAS_AIRFLOW_DAG_LISTENER_API:
self.on_dag_start(dagrun)
# If we don't have the DAG listener API, we just pretend that
# the start of the task is the start of the DAG.
# This generates duplicate events, but it's better than not
# generating anything.
if not HAS_AIRFLOW_DAG_LISTENER_API:
self.on_dag_start(dagrun)

datajob = AirflowGenerator.generate_datajob(
cluster=self.config.cluster,
task=task,
dag=dag,
capture_tags=self.config.capture_tags_info,
capture_owner=self.config.capture_ownership_info,
config=self.config,
)
datajob = AirflowGenerator.generate_datajob(
cluster=self.config.cluster,
task=task,
dag=dag,
capture_tags=self.config.capture_tags_info,
capture_owner=self.config.capture_ownership_info,
config=self.config,
)

# TODO: Make use of get_task_location to extract github urls.
# TODO: Make use of get_task_location to extract github urls.

# Add lineage info.
self._extract_lineage(datajob, dagrun, task, task_instance)
# Add lineage info.
self._extract_lineage(datajob, dagrun, task, task_instance)

# TODO: Add handling for Airflow mapped tasks using task_instance.map_index
# TODO: Add handling for Airflow mapped tasks using task_instance.map_index

for mcp in datajob.generate_mcp(
materialize_iolets=self.config.materialize_iolets
):
self.emitter.emit(mcp, self._make_emit_callback())
logger.debug(f"Emitted DataHub Datajob start: {datajob}")
for mcp in datajob.generate_mcp(
materialize_iolets=self.config.materialize_iolets
):
self.emitter.emit(mcp, self._make_emit_callback())
logger.debug(f"Emitted DataHub Datajob start: {datajob}")

if self.config.capture_executions:
dpi = AirflowGenerator.run_datajob(
emitter=self.emitter,
config=self.config,
ti=task_instance,
dag=dag,
dag_run=dagrun,
datajob=datajob,
emit_templates=False,
)
logger.debug(f"Emitted DataHub DataProcess Instance start: {dpi}")
if self.config.capture_executions:
dpi = AirflowGenerator.run_datajob(
emitter=self.emitter,
config=self.config,
ti=task_instance,
dag=dag,
dag_run=dagrun,
datajob=datajob,
emit_templates=False,
)
logger.debug(f"Emitted DataHub DataProcess Instance start: {dpi}")

self.emitter.flush()
self.emitter.flush()

logger.debug(
f"DataHub listener finished processing notification about task instance start for {task_instance.task_id}"
)
logger.debug(
f"DataHub listener finished processing notification about task instance start for {task_instance.task_id}"
)

if self.config.materialize_iolets:
for outlet in datajob.outlets:
reported_time: int = int(time.time() * 1000)
operation = OperationClass(
timestampMillis=reported_time,
operationType=OperationTypeClass.CREATE,
lastUpdatedTimestamp=reported_time,
actor=builder.make_user_urn("airflow"),
)
if self.config.materialize_iolets:
for outlet in datajob.outlets:
reported_time: int = int(time.time() * 1000)
operation = OperationClass(
timestampMillis=reported_time,
operationType=OperationTypeClass.CREATE,
lastUpdatedTimestamp=reported_time,
actor=builder.make_user_urn("airflow"),
)

operation_mcp = MetadataChangeProposalWrapper(
entityUrn=str(outlet), aspect=operation
)
operation_mcp = MetadataChangeProposalWrapper(
entityUrn=str(outlet), aspect=operation
)

self.emitter.emit(operation_mcp)
logger.debug(f"Emitted Dataset Operation: {outlet}")
self.emitter.emit(operation_mcp)
logger.debug(f"Emitted Dataset Operation: {outlet}")
else:
if self.graph:
for outlet in datajob.outlets:
if not self.graph.exists(str(outlet)):
logger.warning(f"Dataset {str(outlet)} not materialized")
for inlet in datajob.inlets:
if not self.graph.exists(str(inlet)):
logger.warning(f"Dataset {str(inlet)} not materialized")
else:
if self.graph:
for outlet in datajob.outlets:
if not self.graph.exists(str(outlet)):
logger.warning(f"Dataset {str(outlet)} not materialized")
for inlet in datajob.inlets:
if not self.graph.exists(str(inlet)):
logger.warning(f"Dataset {str(inlet)} not materialized")
logger.debug(
f"DAG {task_instance.dag_run.dag_id} is not allowed by the pattern"
)

def on_task_instance_finish(
self, task_instance: "TaskInstance", status: InstanceRunResult
@@ -491,40 +495,45 @@ def on_task_instance_finish(

dag: "DAG" = task.dag # type: ignore[assignment]

datajob = AirflowGenerator.generate_datajob(
cluster=self.config.cluster,
task=task,
dag=dag,
capture_tags=self.config.capture_tags_info,
capture_owner=self.config.capture_ownership_info,
config=self.config,
)

# Add lineage info.
self._extract_lineage(datajob, dagrun, task, task_instance, complete=True)

for mcp in datajob.generate_mcp(
materialize_iolets=self.config.materialize_iolets
):
self.emitter.emit(mcp, self._make_emit_callback())
logger.debug(f"Emitted DataHub Datajob finish w/ status {status}: {datajob}")

if self.config.capture_executions:
dpi = AirflowGenerator.complete_datajob(
emitter=self.emitter,
if self.config.dag_allow_deny_pattern.allowed(dag.dag_id):
datajob = AirflowGenerator.generate_datajob(
cluster=self.config.cluster,
ti=task_instance,
task=task,
dag=dag,
dag_run=dagrun,
datajob=datajob,
result=status,
capture_tags=self.config.capture_tags_info,
capture_owner=self.config.capture_ownership_info,
config=self.config,
)

# Add lineage info.
self._extract_lineage(datajob, dagrun, task, task_instance, complete=True)

for mcp in datajob.generate_mcp(
materialize_iolets=self.config.materialize_iolets
):
self.emitter.emit(mcp, self._make_emit_callback())
logger.debug(
f"Emitted DataHub DataProcess Instance with status {status}: {dpi}"
f"Emitted DataHub Datajob finish w/ status {status}: {datajob}"
)

self.emitter.flush()
if self.config.capture_executions:
dpi = AirflowGenerator.complete_datajob(
emitter=self.emitter,
cluster=self.config.cluster,
ti=task_instance,
dag=dag,
dag_run=dagrun,
datajob=datajob,
result=status,
config=self.config,
)
logger.debug(
f"Emitted DataHub DataProcess Instance with status {status}: {dpi}"
)

self.emitter.flush()
else:
logger.debug(f"DAG {dag.dag_id} is not allowed by the pattern")

@hookimpl
@run_in_thread
@@ -688,15 +697,12 @@ def on_dag_run_running(self, dag_run: "DagRun", msg: str) -> None:
f"DataHub listener got notification about dag run start for {dag_run.dag_id}"
)

# convert allow_deny_pattern string to AllowDenyPattern object
dag_allow_deny_pattern_model = AllowDenyPattern.parse_raw(
self.config.dag_allow_deny_pattern_str
)

assert dag_run.dag_id
if dag_allow_deny_pattern_model.allowed(dag_run.dag_id):
if self.config.dag_allow_deny_pattern.allowed(dag_run.dag_id):
self.on_dag_start(dag_run)
self.emitter.flush()
else:
logger.debug(f"DAG {dag_run.dag_id} is not allowed by the pattern")

# TODO: Add hooks for on_dag_run_success, on_dag_run_failed -> call AirflowGenerator.complete_dataflow

@@ -0,0 +1,34 @@
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

from datahub_airflow_plugin.entities import Dataset, Urn

with DAG(
"dag_to_filter_from_ingestion",
start_date=datetime(2023, 1, 1),
schedule_interval=None,
catchup=False,
) as dag:
task1 = BashOperator(
task_id="task_dag_to_filter_from_ingestion_task_1",
dag=dag,
bash_command="echo 'task_dag_to_filter_from_ingestion_task_1'",
inlets=[
Dataset(platform="snowflake", name="mydb.schema.tableA"),
Urn(
"urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableC,PROD)"
),
Urn("urn:li:dataJob:(urn:li:dataFlow:(airflow,test_dag,PROD),test_task)"),
],
outlets=[Dataset("snowflake", "mydb.schema.tableD")],
)

task2 = BashOperator(
task_id="task_dag_to_filter_from_ingestion_task_2",
dag=dag,
bash_command="echo 'task_dag_to_filter_from_ingestion_task_2'",
)

task1 >> task2
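
For context on the new test fixture above, a minimal sketch of how this DAG id would be excluded, assuming a pattern that denies it (the concrete pattern value is an assumption, not the committed test configuration):

from datahub.configuration.common import AllowDenyPattern

# Assumed configuration that excludes the fixture DAG by id.
pattern = AllowDenyPattern.parse_raw(
    '{"allow": [".*"], "deny": ["dag_to_filter_from_ingestion"]}'
)

# The listener checks this before emitting anything; a denied DAG only produces
# the "... is not allowed by the pattern" debug log shown in the listener diff above.
assert not pattern.allowed("dag_to_filter_from_ingestion")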