Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deprecate KubernetesPodOperatorAsync #1465

Merged
merged 5 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"-cx",
(
"i=0; "
"while [ $i -ne 30 ]; "
"while [ $i -ne 150 ]; "
"do i=$(($i+1)); "
"echo $i; "
"sleep 1; "
Expand All @@ -55,6 +55,7 @@
),
],
do_xcom_push=True,
logging_interval=5,
)
# [END howto_operator_kubernetes_pod_async]

Expand Down
145 changes: 15 additions & 130 deletions astronomer/providers/cncf/kubernetes/operators/kubernetes_pod.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
from __future__ import annotations

import warnings
from typing import Any, Dict, Optional
from typing import Any

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
KubernetesPodOperator,
)
from kubernetes.client import models as k8s
from pendulum import DateTime

from astronomer.providers.cncf.kubernetes.triggers.wait_container import (
PodLaunchTimeoutException,
WaitContainerTrigger,
)
from astronomer.providers.utils.typing_compat import Context


class PodNotFoundException(AirflowException):
Expand All @@ -21,129 +15,20 @@ class PodNotFoundException(AirflowException):

class KubernetesPodOperatorAsync(KubernetesPodOperator):
"""
Async (deferring) version of KubernetesPodOperator
This class is deprecated.

.. warning::
By default, logs will not be available in the Airflow Webserver until the task completes. However,
you can configure ``KubernetesPodOperatorAsync`` to periodically resume and fetch logs. This behavior
is controlled by param ``logging_interval``.

:param poll_interval: interval in seconds to sleep between checking pod status
:param logging_interval: max time in seconds that task should be in deferred state before
resuming to fetch latest logs. If ``None``, then the task will remain in deferred state until pod
is done, and no logs will be visible until that time.
Please use :class:`~airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator`
and set `deferrable` param to `True` instead.
"""

def __init__(self, *, poll_interval: int = 5, logging_interval: Optional[int] = None, **kwargs: Any):
self.poll_interval = poll_interval
self.logging_interval = logging_interval
super().__init__(**kwargs)

@staticmethod
def raise_for_trigger_status(event: Dict[str, Any]) -> None:
"""Raise exception if pod is not in expected state."""
if event["status"] == "error":
error_type = event["error_type"]
description = event["description"]
if error_type == "PodLaunchTimeoutException":
raise PodLaunchTimeoutException(description)
else:
raise AirflowException(description)

def defer(self, last_log_time: Optional[DateTime] = None, **kwargs: Any) -> None:
"""Defers to ``WaitContainerTrigger`` optionally with last log time."""
if kwargs:
raise ValueError(
f"Received keyword arguments {list(kwargs.keys())} but "
f"they are not used in this implementation of `defer`."
)
super().defer(
trigger=WaitContainerTrigger(
kubernetes_conn_id=self.kubernetes_conn_id,
hook_params={
"cluster_context": self.cluster_context,
"config_file": self.config_file,
"in_cluster": self.in_cluster,
},
pod_name=self.pod.metadata.name,
container_name=self.BASE_CONTAINER_NAME,
pod_namespace=self.pod.metadata.namespace,
pending_phase_timeout=self.startup_timeout_seconds,
poll_interval=self.poll_interval,
logging_interval=self.logging_interval,
last_log_time=last_log_time,
),
method_name=self.trigger_reentry.__name__,
)

def execute(self, context: Context) -> None: # noqa: D102
self.pod_request_obj = self.build_pod_request_obj(context)
self.pod: k8s.V1Pod = self.get_or_create_pod(self.pod_request_obj, context)
self.defer()

def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any: # type: ignore[override]
"""Deprecated; replaced by trigger_reentry."""
def __init__(self, **kwargs: Any):
warnings.warn(
"Method `execute_complete` is deprecated and replaced with method `trigger_reentry`.",
(
"This module is deprecated."
"Please use `airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator`"
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
self.trigger_reentry(context=context, event=event)

def trigger_reentry(self, context: Context, event: Dict[str, Any]) -> Any:
"""
Point of re-entry from trigger.

If ``logging_interval`` is None, then at this point the pod should be done and we'll just fetch
the logs and exit.

If ``logging_interval`` is not None, it could be that the pod is still running and we'll just
grab the latest logs and defer back to the trigger again.
"""
remote_pod = None
try:
self.pod_request_obj = self.build_pod_request_obj(context)
self.pod = self.find_pod(
namespace=self.namespace or self.pod_request_obj.metadata.namespace,
context=context,
)

# we try to find pod before possibly raising so that on_kill will have `pod` attr
self.raise_for_trigger_status(event)

if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")

if self.get_logs:
last_log_time = event and event.get("last_log_time")
if last_log_time:
self.log.info("Resuming logs read from time %r", last_log_time)
pod_log_status = self.pod_manager.fetch_container_logs(
pod=self.pod,
container_name=self.BASE_CONTAINER_NAME,
follow=self.logging_interval is None,
since_time=last_log_time,
)
if pod_log_status.running:
self.log.info("Container still running; deferring again.")
self.defer(pod_log_status.last_log_time)

if self.do_xcom_push:
result = self.extract_xcom(pod=self.pod)
remote_pod = self.pod_manager.await_pod_completion(self.pod)
except TaskDeferred:
raise
except Exception:
self.cleanup(
pod=self.pod or self.pod_request_obj,
remote_pod=remote_pod,
)
raise
self.cleanup(
pod=self.pod or self.pod_request_obj,
remote_pod=remote_pod,
)
ti = context["ti"]
ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
if self.do_xcom_push:
return result
super().__init__(deferrable=True, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we also need to pass logging_interval

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's present in the kwargs, it will be passed in. The previous default value here for logging_interval was None, and the upstream default is None too, so no additional care needs to be taken.

24 changes: 11 additions & 13 deletions astronomer/providers/cncf/kubernetes/triggers/wait_container.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import traceback
import warnings
from datetime import timedelta
from typing import Any, AsyncIterator, Dict, Optional, Tuple

Expand All @@ -22,19 +23,8 @@ class PodLaunchTimeoutException(AirflowException):

class WaitContainerTrigger(BaseTrigger):
"""
First, waits for pod ``pod_name`` to reach running state within ``pending_phase_timeout``.
Next, waits for ``container_name`` to reach a terminal state.

:param kubernetes_conn_id: Airflow connection ID to use
:param hook_params: kwargs for hook
:param container_name: container to wait for
:param pod_name: name of pod to monitor
:param pod_namespace: pod namespace
:param pending_phase_timeout: max time in seconds to wait for pod to leave pending phase
:param poll_interval: number of seconds between reading pod state
:param logging_interval: number of seconds to wait before kicking it back to
the operator to print latest logs. If ``None`` will wait until container done.
:param last_log_time: where to resume logs from
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger` instead.
"""

def __init__(
Expand All @@ -50,6 +40,14 @@ def __init__(
logging_interval: Optional[int] = None,
last_log_time: Optional[DateTime] = None,
):
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger`"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.kubernetes_conn_id = kubernetes_conn_id
self.hook_params = hook_params
Expand Down
162 changes: 7 additions & 155 deletions tests/cncf/kubernetes/operators/test_kubernetes_pod.py
Original file line number Diff line number Diff line change
@@ -1,163 +1,15 @@
from unittest import mock
from unittest.mock import MagicMock

import pytest
from airflow.exceptions import TaskDeferred
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodLoggingStatus
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

from astronomer.providers.cncf.kubernetes.operators.kubernetes_pod import (
KubernetesPodOperatorAsync,
PodNotFoundException,
)
from astronomer.providers.cncf.kubernetes.triggers.wait_container import (
PodLaunchTimeoutException,
)
from tests.utils.airflow_util import create_context

KUBE_POD_MOD = "astronomer.providers.cncf.kubernetes.operators.kubernetes_pod"


class TestKubernetesPodOperatorAsync:
def test_raise_for_trigger_status_pending_timeout(self):
"""Assert trigger raise exception in case of timeout"""
with pytest.raises(PodLaunchTimeoutException):
KubernetesPodOperatorAsync.raise_for_trigger_status(
{
"status": "error",
"error_type": "PodLaunchTimeoutException",
"description": "any message",
}
)

def test_raise_for_trigger_status_done(self):
"""Assert trigger don't raise exception in case of status is done"""
assert KubernetesPodOperatorAsync.raise_for_trigger_status({"status": "done"}) is None

@mock.patch("airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.client")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.cleanup")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.raise_for_trigger_status")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.find_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
@mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook._get_default_client")
def test_get_logs_running(
self,
mock_get_default_client,
fetch_container_logs,
await_pod_completion,
find_pod,
raise_for_trigger_status,
cleanup,
mock_client,
):
"""When logs fetch exits with status running, raise task deferred"""
pod = MagicMock()
find_pod.return_value = pod
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
mock_client.return_value = {}
context = create_context(op)
await_pod_completion.return_value = None
fetch_container_logs.return_value = PodLoggingStatus(True, None)
with pytest.raises(TaskDeferred):
op.trigger_reentry(context, None)
fetch_container_logs.is_called_with(pod, "base")

@mock.patch("airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.client")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.cleanup")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.raise_for_trigger_status")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.find_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
@mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook._get_default_client")
def test_get_logs_not_running(
self,
mock_get_default_client,
fetch_container_logs,
await_pod_completion,
find_pod,
raise_for_trigger_status,
cleanup,
mock_client,
):
pod = MagicMock()
find_pod.return_value = pod
mock_client.return_value = {}
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
context = create_context(op)
await_pod_completion.return_value = None
fetch_container_logs.return_value = PodLoggingStatus(False, None)
op.trigger_reentry(context, None)
fetch_container_logs.is_called_with(pod, "base")

@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.cleanup")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.raise_for_trigger_status")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.find_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
@mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook._get_default_client")
def test_no_pod(
self,
mock_get_default_client,
fetch_container_logs,
await_pod_completion,
find_pod,
raise_for_trigger_status,
cleanup,
):
"""Assert if pod not found then raise exception"""
find_pod.return_value = None
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
context = create_context(op)
with pytest.raises(PodNotFoundException):
op.trigger_reentry(context, None)

@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.cleanup")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.find_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
@mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook._get_default_client")
def test_trigger_error(
self,
mock_get_default_client,
fetch_container_logs,
await_pod_completion,
find_pod,
cleanup,
):
"""Assert that trigger_reentry raise exception in case of error"""
find_pod.return_value = MagicMock()
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
with pytest.raises(PodLaunchTimeoutException):
context = create_context(op)
op.trigger_reentry(
context,
{
"status": "error",
"error_type": "PodLaunchTimeoutException",
"description": "any message",
},
)

def test_defer_with_kwargs(self):
"""Assert that with kwargs throw exception"""
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
with pytest.raises(ValueError):
op.defer(kwargs={"timeout": 10})

@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.build_pod_request_obj")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.get_or_create_pod")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.defer")
def test_execute(self, mock_defer, mock_get_or_create_pod, mock_build_pod_request_obj):
"""Assert that execute succeeded"""
mock_get_or_create_pod.return_value = {}
mock_build_pod_request_obj.return_value = {}
mock_defer.return_value = {}
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
assert op.execute(context=create_context(op)) is None
def test_init(self):
task = KubernetesPodOperatorAsync(
task_id="test_task", name="test-pod", get_logs=True, logging_interval=5
)

@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.trigger_reentry")
def test_execute_complete(self, mock_trigger_reentry):
"""Assert that execute_complete succeeded"""
mock_trigger_reentry.return_value = {}
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
assert op.execute_complete(context=create_context(op), event={}) is None
assert isinstance(task, KubernetesPodOperator)
assert task.deferrable is True
Loading