Skip to content

Commit

Permalink
Add on_warning_callback to DbtSourceKubernetesOperator and refactor previous operators (#1501)
Browse files Browse the repository at this point in the history

It seems that `DbtSourceKubernetesOperator` is missing the logic for
handling `on_warning_callback`. This logic was added for
`DbtTestKubernetesOperator` in #673; I tried to apply the same logic
to the source operator as well.
 
## Related Issue(s)

closes #1500
  • Loading branch information
LuigiCerone authored Feb 3, 2025
1 parent 9c175f6 commit 3eb67bd
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 30 deletions.
68 changes: 40 additions & 28 deletions cosmos/operators/kubernetes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from abc import ABC
from os import PathLike
from typing import Any, Callable, Sequence

Expand Down Expand Up @@ -136,31 +137,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


class DbtSourceKubernetesOperator(DbtSourceMixin, DbtKubernetesBaseOperator):
    """
    Executes a dbt source freshness command.

    Combines ``DbtSourceMixin`` with ``DbtKubernetesBaseOperator``; the constructor
    adds no behaviour of its own.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Forward every argument unchanged to the next ``__init__`` in the MRO.
        super().__init__(*args, **kwargs)


class DbtRunKubernetesOperator(DbtRunMixin, DbtKubernetesBaseOperator):
    """
    Executes a dbt core run command.

    Combines ``DbtRunMixin`` with ``DbtKubernetesBaseOperator``; the constructor
    adds no behaviour of its own.
    """

    # Template fields are the base operator's fields followed by the run mixin's fields.
    template_fields: Sequence[str] = DbtKubernetesBaseOperator.template_fields + DbtRunMixin.template_fields  # type: ignore[operator]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Forward every argument unchanged to the next ``__init__`` in the MRO.
        super().__init__(*args, **kwargs)


class DbtTestKubernetesOperator(DbtTestMixin, DbtKubernetesBaseOperator):
"""
Executes a dbt core test command.
"""

class DbtWarningKubernetesOperator(DbtKubernetesBaseOperator, ABC):
def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any) -> None:
if not on_warning_callback:
super().__init__(**kwargs)
Expand All @@ -181,7 +158,7 @@ def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwar
kwargs["is_delete_operator_pod"] = False
kwargs["on_finish_action"] = OnFinishAction.KEEP_POD

# Add an additional callback to both success and failure callbacks.
# Add a callback to both success and failure callbacks.
# In case of success, check for a warning in the logs and clean up the pod.
self.on_success_callback = kwargs.get("on_success_callback", None) or []
if isinstance(self.on_success_callback, list):
Expand All @@ -208,7 +185,10 @@ def _handle_warnings(self, context: Context) -> None:
"""
if not (
isinstance(context["task_instance"], TaskInstance)
and isinstance(context["task_instance"].task, DbtTestKubernetesOperator)
and (
isinstance(context["task_instance"].task, DbtTestKubernetesOperator)
or isinstance(context["task_instance"].task, DbtSourceKubernetesOperator)
)
):
return
task = context["task_instance"].task
Expand Down Expand Up @@ -243,7 +223,10 @@ def _cleanup_pod(self, context: Context) -> None:
"""
if not (
isinstance(context["task_instance"], TaskInstance)
and isinstance(context["task_instance"].task, DbtTestKubernetesOperator)
and (
isinstance(context["task_instance"].task, DbtTestKubernetesOperator)
or isinstance(context["task_instance"].task, DbtSourceKubernetesOperator)
)
):
return
task = context["task_instance"].task
Expand All @@ -252,6 +235,35 @@ def _cleanup_pod(self, context: Context) -> None:
task.cleanup(pod=task.pod, remote_pod=task.remote_pod)


class DbtTestKubernetesOperator(DbtTestMixin, DbtWarningKubernetesOperator):
    """
    Executes a dbt core test command.

    Combines ``DbtTestMixin`` with ``DbtWarningKubernetesOperator``, whose constructor
    accepts an optional ``on_warning_callback``. This class adds no constructor logic
    of its own.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Forward every argument unchanged to the next ``__init__`` in the MRO.
        super().__init__(*args, **kwargs)


class DbtSourceKubernetesOperator(DbtSourceMixin, DbtWarningKubernetesOperator):
    """
    Executes a dbt source freshness command.

    Combines ``DbtSourceMixin`` with ``DbtWarningKubernetesOperator``, whose constructor
    accepts an optional ``on_warning_callback``. This class adds no constructor logic
    of its own.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Forward every argument unchanged to the next ``__init__`` in the MRO.
        super().__init__(*args, **kwargs)


class DbtRunKubernetesOperator(DbtRunMixin, DbtKubernetesBaseOperator):
    """
    Executes a dbt core run command.

    Combines ``DbtRunMixin`` with ``DbtKubernetesBaseOperator``; the constructor
    adds no behaviour of its own.
    """

    # Template fields are the base operator's fields followed by the run mixin's fields.
    template_fields: Sequence[str] = DbtKubernetesBaseOperator.template_fields + DbtRunMixin.template_fields  # type: ignore[operator]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Forward every argument unchanged to the next ``__init__`` in the MRO.
        super().__init__(*args, **kwargs)


class DbtRunOperationKubernetesOperator(DbtRunOperationMixin, DbtKubernetesBaseOperator):
"""
Executes a dbt core run-operation command.
Expand Down
84 changes: 82 additions & 2 deletions tests/operators/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
DbtLSKubernetesOperator,
DbtRunKubernetesOperator,
DbtSeedKubernetesOperator,
DbtSourceKubernetesOperator,
DbtTestKubernetesOperator,
)

Expand Down Expand Up @@ -118,11 +119,9 @@ def test_dbt_kubernetes_operator_get_env(p_context_to_airflow_vars: MagicMock, b
"no_version_check": True,
}


if version.parse(airflow_version) == version.parse("2.4"):
base_kwargs["name"] = "some-pod-name"


result_map = {
"ls": DbtLSKubernetesOperator(**base_kwargs),
"run": DbtRunKubernetesOperator(**base_kwargs),
Expand Down Expand Up @@ -208,6 +207,62 @@ def test_dbt_test_kubernetes_operator_constructor(additional_kwargs, expected_re
assert test_operator.on_finish_action_original == OnFinishAction(expected_results[3])


# Each case is ``(additional_kwargs, expected_results)`` where ``expected_results`` is
# ``(len(on_success_callback), len(on_failure_callback),
#    is_delete_operator_pod_original, on_finish_action_original)``.
@pytest.mark.parametrize(
    "additional_kwargs,expected_results",
    [
        ({"on_success_callback": None, "is_delete_operator_pod": True}, (1, 1, True, "delete_pod")),
        (
            {"on_success_callback": (lambda **kwargs: None), "is_delete_operator_pod": False},
            (2, 1, False, "keep_pod"),
        ),
        (
            {"on_success_callback": [(lambda **kwargs: None), (lambda **kwargs: None)], "is_delete_operator_pod": None},
            (3, 1, True, "delete_pod"),
        ),
        (
            {"on_failure_callback": None, "is_delete_operator_pod": True, "on_finish_action": "keep_pod"},
            (1, 1, True, "delete_pod"),
        ),
        (
            {
                "on_failure_callback": (lambda **kwargs: None),
                "is_delete_operator_pod": None,
                "on_finish_action": "delete_pod",
            },
            (1, 2, True, "delete_pod"),
        ),
        (
            {
                "on_failure_callback": [(lambda **kwargs: None), (lambda **kwargs: None)],
                "is_delete_operator_pod": None,
                "on_finish_action": "delete_succeeded_pod",
            },
            (1, 3, False, "delete_succeeded_pod"),
        ),
        ({"is_delete_operator_pod": None, "on_finish_action": "keep_pod"}, (1, 1, False, "keep_pod")),
        ({}, (1, 1, True, "delete_pod")),
    ],
)
@pytest.mark.skipif(
    not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available"
)
def test_dbt_source_kubernetes_operator_constructor(additional_kwargs, expected_results):
    """
    Check how ``DbtSourceKubernetesOperator`` wires callbacks when ``on_warning_callback`` is given.

    :param additional_kwargs: Extra constructor keyword arguments layered on top of ``base_kwargs``.
    :param expected_results: Tuple of the expected number of success callbacks, the expected number
        of failure callbacks, the expected ``is_delete_operator_pod_original`` value and the expected
        ``on_finish_action_original`` value (as a string accepted by ``OnFinishAction``).
    """
    # Name the tuple members once so the assertions below are self-describing.
    success_count, failure_count, delete_pod_original, finish_action_original = expected_results

    source_operator = DbtSourceKubernetesOperator(
        on_warning_callback=(lambda **kwargs: None), **additional_kwargs, **base_kwargs
    )

    # Both callback attributes are normalised to lists, whatever the caller passed in.
    assert isinstance(source_operator.on_success_callback, list)
    assert isinstance(source_operator.on_failure_callback, list)
    # The operator registers its own hooks alongside any user-supplied callbacks.
    assert source_operator._handle_warnings in source_operator.on_success_callback
    assert source_operator._cleanup_pod in source_operator.on_failure_callback
    assert len(source_operator.on_success_callback) == success_count
    assert len(source_operator.on_failure_callback) == failure_count
    # The user's original pod-deletion settings are preserved on the ``*_original`` attributes.
    assert source_operator.is_delete_operator_pod_original == delete_pod_original
    assert source_operator.on_finish_action_original == OnFinishAction(finish_action_original)


class FakePodManager:
def read_pod_logs(self, pod, container):
assert pod == "pod"
Expand Down Expand Up @@ -259,6 +314,31 @@ def cleanup(pod: str, remote_pod: str):
test_operator._handle_warnings(context)


@pytest.mark.skipif(
    not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available"
)
def test_dbt_source_kubernetes_operator_handle_warnings_and_cleanup_pod():
    """Run ``_handle_warnings`` on a source operator wired to a fake pod manager and a stub ``cleanup``."""

    def cleanup(pod: str, remote_pod: str):
        # Stub installed in place of the operator's ``cleanup``.
        assert pod == remote_pod

    def on_warning_callback(context: Context):
        # Expectations on what the operator passes to the warning callback.
        assert context["test_names"] == ["dbt_utils_accepted_range_table_col__12__0"]
        assert context["test_results"] == ["Got 252 results, configured to warn if >0"]

    source_operator = DbtSourceKubernetesOperator(
        on_warning_callback=on_warning_callback, is_delete_operator_pod=True, **base_kwargs
    )
    task_instance = TaskInstance(source_operator)

    # Replace the Kubernetes-facing pieces of the task with local fakes.
    task = task_instance.task
    task.pod_manager = FakePodManager()
    task.pod = "pod"
    task.remote_pod = "pod"
    task.cleanup = cleanup

    context = Context()
    context_merge(context, task_instance=task_instance)

    source_operator._handle_warnings(context)


def test_created_pod():
ls_kwargs = {"env_vars": {"FOO": "BAR"}, "namespace": "foo", "append_env": False}
ls_kwargs.update(base_kwargs)
Expand Down

0 comments on commit 3eb67bd

Please sign in to comment.