Commit

feat: Add OpenLineage support for CloudSQLExecuteQueryOperator and some SQLtoGCS transfer operators

Signed-off-by: Kacper Muda <[email protected]>
kacpermuda committed Dec 24, 2024
1 parent f56038e commit 2782f2f
Showing 6 changed files with 218 additions and 13 deletions.
48 changes: 48 additions & 0 deletions providers/src/airflow/providers/google/cloud/openlineage/mixins.py
@@ -32,11 +32,59 @@
    SchemaDatasetFacet,
)
from airflow.providers.google.cloud.openlineage.utils import BigQueryJobRunFacet
from airflow.providers.openlineage.extractors import OperatorLineage


BIGQUERY_NAMESPACE = "bigquery"


class _SQLOpenLineageMixin:
    @staticmethod
    def _get_openlineage_facets(hook, sql, conn_id, database, logger) -> OperatorLineage | None:
        try:
            from airflow.providers.openlineage.sqlparser import SQLParser
        except ImportError:
            # OpenLineage provider is not installed - no lineage can be produced.
            return None

        try:
            from airflow.providers.openlineage.utils.utils import should_use_external_connection

            use_external_connection = should_use_external_connection(hook)
        except ImportError:
            # OpenLineage provider release < 1.8.0 - we always use connection
            use_external_connection = True

        connection = hook.get_connection(conn_id)
        try:
            database_info = hook.get_openlineage_database_info(connection)
        except AttributeError:
            logger.debug("%s has no database info provided", hook)
            database_info = None

        if database_info is None:
            return None

        try:
            sql_parser = SQLParser(
                dialect=hook.get_openlineage_database_dialect(connection),
                default_schema=hook.get_openlineage_default_schema(),
            )
        except AttributeError:
            logger.debug("%s failed to get database dialect", hook)
            return None

        operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
            sql=sql,
            hook=hook,
            database_info=database_info,
            database=database,
            sqlalchemy_engine=hook.get_sqlalchemy_engine(),
            use_connection=use_external_connection,
        )

        return operator_lineage
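
For orientation, a minimal sketch of how an operator is meant to consume this mixin; MyDbOperator and its db_hook attribute are hypothetical stand-ins, and the real wiring is in the operators changed below:

from airflow.models.baseoperator import BaseOperator

class MyDbOperator(BaseOperator, _SQLOpenLineageMixin):
    def __init__(self, *, sql, conn_id, **kwargs):
        super().__init__(**kwargs)
        self.sql = sql
        self.conn_id = conn_id

    def get_openlineage_facets_on_complete(self, _):
        # db_hook is assumed to be a DbApiHook bound to conn_id (not defined in this sketch).
        return self._get_openlineage_facets(
            hook=self.db_hook, sql=self.sql, conn_id=self.conn_id, database=None, logger=self.log
        )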


class _BigQueryOpenLineageMixin:
    def get_openlineage_facets_on_complete(self, _):
        """
providers/src/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -20,6 +20,7 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
from contextlib import contextmanager
from functools import cached_property
from typing import TYPE_CHECKING, Any

@@ -30,6 +31,7 @@
from airflow.hooks.base import BaseHook
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook
from airflow.providers.google.cloud.links.cloud_sql import CloudSQLInstanceDatabaseLink, CloudSQLInstanceLink
from airflow.providers.google.cloud.openlineage.mixins import _SQLOpenLineageMixin
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger
from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator
@@ -38,8 +40,7 @@

if TYPE_CHECKING:
    from airflow.models import Connection
-    from airflow.providers.mysql.hooks.mysql import MySqlHook
-    from airflow.providers.postgres.hooks.postgres import PostgresHook
+    from airflow.providers.openlineage.extractors import OperatorLineage
    from airflow.utils.context import Context


@@ -1167,7 +1168,7 @@ def execute(self, context: Context) -> None:
        return hook.import_instance(project_id=self.project_id, instance=self.instance, body=self.body)


-class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
+class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator, _SQLOpenLineageMixin):
    """
    Perform DML or DDL query on an existing Cloud SQL instance.
@@ -1256,7 +1257,8 @@ def __init__(
        self.ssl_client_key = ssl_client_key
        self.ssl_secret_id = ssl_secret_id

-    def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook | MySqlHook) -> None:
+    @contextmanager
+    def cloud_sql_proxy_context(self, hook: CloudSQLDatabaseHook):
        cloud_sql_proxy_runner = None
        try:
            if hook.use_proxy:
@@ -1266,8 +1268,7 @@
                # be taken over here by another bind(0).
                # It's quite unlikely to happen though!
                cloud_sql_proxy_runner.start_proxy()
-            self.log.info('Executing: "%s"', self.sql)
-            database_hook.run(self.sql, self.autocommit, parameters=self.parameters)
+            yield
        finally:
            if cloud_sql_proxy_runner:
                cloud_sql_proxy_runner.stop_proxy()
@@ -1281,7 +1282,9 @@ def execute(self, context: Context):
        hook.validate_socket_path_length()
        database_hook = hook.get_database_hook(connection=connection)
        try:
-            self._execute_query(hook, database_hook)
+            with self.cloud_sql_proxy_context(hook):
+                self.log.info('Executing: "%s"', self.sql)
+                database_hook.run(self.sql, self.autocommit, parameters=self.parameters)
        finally:
            hook.cleanup_database_hook()
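
The refactor above follows the standard contextlib pattern, setup before the yield and teardown in the finally block, so any caller that needs the Cloud SQL proxy alive (query execution here, lineage extraction below) can share one lifecycle. A self-contained sketch of the pattern, with a hypothetical runner object:

from contextlib import contextmanager

@contextmanager
def proxy_context(runner=None):
    # Start the (hypothetical) proxy runner, hand control back to the caller,
    # and guarantee teardown even if the caller's block raises.
    try:
        if runner:
            runner.start_proxy()
        yield
    finally:
        if runner:
            runner.stop_proxy()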

@@ -1297,3 +1300,13 @@ def hook(self):
            ssl_key=self.ssl_client_key,
            ssl_secret_id=self.ssl_secret_id,
        )

    def get_openlineage_facets_on_complete(self, _) -> OperatorLineage | None:
        with self.cloud_sql_proxy_context(self.hook):
            return self._get_openlineage_facets(
                hook=self.hook.db_hook,
                sql=self.sql,
                conn_id=self.gcp_cloudsql_conn_id,
                database=self.hook.database,
                logger=self.log,
            )
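
For reference, a minimal usage sketch; the connection ID and SQL below are placeholders, and once the OpenLineage provider is installed, lineage is collected automatically through the method above with no extra operator arguments:

from airflow.providers.google.cloud.operators.cloud_sql import CloudSQLExecuteQueryOperator

run_query = CloudSQLExecuteQueryOperator(
    task_id="run_query",
    gcp_cloudsql_conn_id="my_cloudsql_conn",  # placeholder connection ID
    sql="INSERT INTO orders VALUES (1, 'pending');",
)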
providers/src/airflow/providers/google/cloud/transfers/mysql_to_gcs.py
@@ -22,6 +22,8 @@
import base64
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from functools import cached_property
from typing import TYPE_CHECKING

try:
    from MySQLdb.constants import FIELD_TYPE
@@ -34,11 +36,15 @@
    )


from airflow.providers.google.cloud.openlineage.mixins import _SQLOpenLineageMixin
from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook

if TYPE_CHECKING:
    from airflow.providers.openlineage.extractors import OperatorLineage

-class MySQLToGCSOperator(BaseSQLToGCSOperator):
+class MySQLToGCSOperator(BaseSQLToGCSOperator, _SQLOpenLineageMixin):
"""
Copy data from MySQL to Google Cloud Storage in JSON, CSV or Parquet format.
Expand Down Expand Up @@ -77,10 +83,13 @@ def __init__(self, *, mysql_conn_id="mysql_default", ensure_utc=False, **kwargs)
        self.mysql_conn_id = mysql_conn_id
        self.ensure_utc = ensure_utc

    @cached_property
    def db_hook(self) -> MySqlHook:
        return MySqlHook(mysql_conn_id=self.mysql_conn_id)

    def query(self):
        """Query MySQL and return a cursor to the results."""
-        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
-        conn = mysql.get_conn()
+        conn = self.db_hook.get_conn()
        cursor = conn.cursor()
        if self.ensure_utc:
            # Ensure TIMESTAMP results are in UTC
@@ -140,3 +149,16 @@ def convert_type(self, value, schema_type: str, **kwargs):
        else:
            value = base64.standard_b64encode(value).decode("ascii")
        return value

    def get_openlineage_facets_on_complete(self, _) -> OperatorLineage | None:
        from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
        from airflow.providers.openlineage.extractors import OperatorLineage

        sql_parsing_result = self._get_openlineage_facets(
            hook=self.db_hook, sql=self.sql, conn_id=self.mysql_conn_id, database=None, logger=self.log
        )
        gcs_output_datasets = self._get_openlineage_output_datasets()
        if sql_parsing_result:
            sql_parsing_result.outputs = gcs_output_datasets
            return sql_parsing_result
        return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
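
When SQL parsing yields nothing, the method above still reports the GCS output plus the raw query as a job facet. A sketch of that fallback result, with placeholder bucket, path, and query:

from airflow.providers.common.compat.openlineage.facet import OutputDataset, SQLJobFacet
from airflow.providers.openlineage.extractors import OperatorLineage

fallback = OperatorLineage(
    outputs=[OutputDataset(namespace="gs://my-bucket", name="exports/orders")],  # placeholders
    job_facets={"sql": SQLJobFacet(query="SELECT * FROM orders")},
)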
providers/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py
@@ -24,13 +24,19 @@
import time
import uuid
from decimal import Decimal
from functools import cached_property
from typing import TYPE_CHECKING

import pendulum
from slugify import slugify

from airflow.providers.google.cloud.openlineage.mixins import _SQLOpenLineageMixin
from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.postgres.hooks.postgres import PostgresHook

if TYPE_CHECKING:
    from airflow.providers.openlineage.extractors import OperatorLineage


class _PostgresServerSideCursorDecorator:
"""
Expand Down Expand Up @@ -67,7 +73,7 @@ def description(self):
return self.cursor.description


-class PostgresToGCSOperator(BaseSQLToGCSOperator):
+class PostgresToGCSOperator(BaseSQLToGCSOperator, _SQLOpenLineageMixin):
"""
Copy data from Postgres to Google Cloud Storage in JSON, CSV or Parquet format.
Expand Down Expand Up @@ -132,10 +138,13 @@ def _unique_name(self):
            )
        return None

    @cached_property
    def db_hook(self) -> PostgresHook:
        return PostgresHook(postgres_conn_id=self.postgres_conn_id)

    def query(self):
        """Query Postgres and return a cursor to the results."""
-        hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
-        conn = hook.get_conn()
+        conn = self.db_hook.get_conn()
        cursor = conn.cursor(name=self._unique_name())
        cursor.execute(self.sql, self.parameters)
        if self.use_server_side_cursor:
@@ -180,3 +189,16 @@ def convert_type(self, value, schema_type, stringify_dict=True):
        if isinstance(value, Decimal):
            return float(value)
        return value

    def get_openlineage_facets_on_complete(self, _) -> OperatorLineage | None:
        from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
        from airflow.providers.openlineage.extractors import OperatorLineage

        sql_parsing_result = self._get_openlineage_facets(
            hook=self.db_hook, sql=self.sql, conn_id=self.postgres_conn_id, database=None, logger=self.log
        )
        gcs_output_datasets = self._get_openlineage_output_datasets()
        if sql_parsing_result:
            sql_parsing_result.outputs = gcs_output_datasets
            return sql_parsing_result
        return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
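
A side note on the cached_property introduced for both transfer operators: it guarantees that query() and get_openlineage_facets_on_complete() see the same hook instance instead of constructing a new one on each access. A standalone illustration:

from functools import cached_property

class Example:
    @cached_property
    def db_hook(self):
        # Built on first access, then reused; stands in for PostgresHook/MySqlHook.
        return object()

e = Example()
assert e.db_hook is e.db_hook  # same instance every time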
providers/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -34,6 +34,7 @@
from airflow.providers.google.cloud.hooks.gcs import GCSHook

if TYPE_CHECKING:
    from airflow.providers.common.compat.openlineage.facet import OutputDataset
    from airflow.utils.context import Context


@@ -151,6 +152,7 @@ def __init__(
        self.partition_columns = partition_columns
        self.write_on_empty = write_on_empty
        self.parquet_row_group_size = parquet_row_group_size
        self._uploaded_file_names: list[str] = []

    def execute(self, context: Context):
        if self.partition_columns:
@@ -501,3 +503,16 @@ def _upload_to_gcs(self, file_to_upload):
            gzip=self.gzip if is_data_file else False,
            metadata=metadata,
        )
        self._uploaded_file_names.append(object_name)

    def _get_openlineage_output_datasets(self) -> list[OutputDataset]:
        from airflow.providers.common.compat.openlineage.facet import OutputDataset
        from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path

        if self._uploaded_file_names:
            unique_names = {extract_ds_name_from_gcs_path(name) for name in self._uploaded_file_names}
            return [OutputDataset(namespace=f"gs://{self.bucket}", name=name) for name in unique_names]

        return [
            OutputDataset(namespace=f"gs://{self.bucket}", name=extract_ds_name_from_gcs_path(self.filename))
        ]
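
A standalone illustration of the deduplication above, assuming (as its name suggests) that extract_ds_name_from_gcs_path maps an object path to a directory-level dataset name; the stand-in below is a simplification, and the real rules live in airflow.providers.google.cloud.openlineage.utils:

import posixpath

def ds_name(path: str) -> str:
    # Simplified stand-in for extract_ds_name_from_gcs_path (an assumption).
    return posixpath.dirname(path) or "/"

uploaded = ["exports/orders/part_0.csv", "exports/orders/part_1.csv"]
print({ds_name(p) for p in uploaded})  # {'exports/orders'}: one dataset, not two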
85 changes: 85 additions & 0 deletions providers/tests/google/cloud/operators/test_cloud_sql.py
@@ -19,11 +19,19 @@

import os
from unittest import mock
from unittest.mock import MagicMock

import pytest

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import Connection
from airflow.providers.common.compat.openlineage.facet import (
    Dataset,
    SchemaDatasetFacet,
    SchemaDatasetFacetFields,
    SQLJobFacet,
)
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.cloud.operators.cloud_sql import (
    CloudSQLCloneInstanceOperator,
    CloudSQLCreateInstanceDatabaseOperator,
@@ -822,3 +830,80 @@ def test_create_operator_with_too_long_unix_socket_path(self, get_connection):
            operator.execute(None)
        err = ctx.value
        assert "The UNIX socket path length cannot exceed" in str(err)


@pytest.mark.parametrize(
"connection_port, default_port, expected_port",
[(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
)
def test_execute_openlineage_events(connection_port, default_port, expected_port):
    class DBApiHookForTests(DbApiHook):
        conn_name_attr = "sql_default"
        get_conn = MagicMock(name="conn")
        get_connection = MagicMock()

        def get_openlineage_database_info(self, connection):
            from airflow.providers.openlineage.sqlparser import DatabaseInfo

            return DatabaseInfo(
                scheme="sqlscheme",
                authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
            )

    dbapi_hook = DBApiHookForTests()

    class CloudSQLExecuteQueryOperatorForTest(CloudSQLExecuteQueryOperator):
        @property
        def hook(self):
            return MagicMock(db_hook=dbapi_hook, database="")

sql = """CREATE TABLE IF NOT EXISTS popular_orders_day_of_week (
order_day_of_week VARCHAR(64) NOT NULL,
order_placed_on TIMESTAMP NOT NULL,
orders_placed INTEGER NOT NULL
);
FORGOT TO COMMENT"""
op = CloudSQLExecuteQueryOperatorForTest(task_id="task_id", sql=sql)
DB_SCHEMA_NAME = "PUBLIC"
rows = [
(DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
(DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
(DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
]
dbapi_hook.get_connection.return_value = Connection(
conn_id="sql_default", conn_type="postgresql", host="host", port=connection_port
)
dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]

lineage = op.get_openlineage_facets_on_complete(None)
assert len(lineage.inputs) == 0
assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)}
assert lineage.run_facets["extractionError"].failedTasks == 1
assert lineage.outputs == [
Dataset(
namespace=f"sqlscheme://host:{expected_port}",
name="PUBLIC.popular_orders_day_of_week",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
SchemaDatasetFacetFields(name="orders_placed", type="int4"),
]
)
},
)
]
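
The three parameter tuples pin down how the dataset namespace authority is derived: an explicit port on the connection wins over the hook-supplied default. A simplified sketch of that precedence, not the provider's exact implementation:

def authority(host, connection_port, default_port):
    # The connection's own port beats the dialect default.
    port = connection_port or default_port
    return f"{host}:{port}" if port else host

assert authority("host", None, 4321) == "host:4321"
assert authority("host", 1234, 4321) == "host:1234"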


def test_with_no_openlineage_provider():
    import importlib

    def mock__import__(name, globals_=None, locals_=None, fromlist=(), level=0):
        if level == 0 and name.startswith("airflow.providers.openlineage"):
            raise ImportError("No provider 'apache-airflow-providers-openlineage'")
        return importlib.__import__(name, globals=globals_, locals=locals_, fromlist=fromlist, level=level)

    with mock.patch("builtins.__import__", side_effect=mock__import__):
        op = CloudSQLExecuteQueryOperator(task_id="task_id", sql="SELECT 1;")
        assert op.get_openlineage_facets_on_complete(None) is None
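
The level == 0 guard in the mock above matters: only absolute imports of the openlineage provider are blocked, while relative imports and everything else fall through to the real import machinery, so the module under test still loads. A minimal demonstration of the same trick, with a placeholder package name:

import builtins
from unittest import mock

real_import = builtins.__import__

def blocking_import(name, *args, **kwargs):
    # Simulate an uninstalled package for one namespace only (placeholder name).
    if name.startswith("some.optional.dependency"):
        raise ImportError(name)
    return real_import(name, *args, **kwargs)

with mock.patch("builtins.__import__", side_effect=blocking_import):
    pass  # code under test runs here; importing some.optional.dependency fails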
