Skip to content

Commit

Permalink
Deprecate LivyOperatorAsync (#1454)
Browse files Browse the repository at this point in the history
Deprecate LivyOperatorAsync and proxy it to its Airflow
OSS provider's counterpart

closes: #1421
  • Loading branch information
pankajkoti authored Jan 24, 2024
1 parent 3b136fb commit 12cc4a7
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 273 deletions.
17 changes: 7 additions & 10 deletions astronomer/providers/apache/livy/hooks/livy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This module contains the Apache Livy hook async."""
import asyncio
import re
import warnings
from typing import Any, Dict, List, Optional, Sequence, Union

import aiohttp
Expand All @@ -15,16 +16,8 @@

class LivyHookAsync(HttpHookAsync, LoggingMixin):
"""
Hook for Apache Livy through the REST API using LivyHookAsync
:param livy_conn_id: reference to a pre-defined Livy Connection.
:param extra_options: Additional option can be passed when creating a request.
For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
:param extra_headers: A dictionary of headers passed to the HTTP request to livy.
.. seealso::
For more details refer to the Apache Livy API reference:
`Apache Livy API reference <https://livy.apache.org/docs/latest/rest-api.html>`_
    This class is deprecated and will be removed in 2.0.0.
    Use :class:`~airflow.providers.apache.livy.hooks.livy.LivyHook` instead.
"""

TERMINAL_STATES = {
Expand All @@ -47,6 +40,10 @@ def __init__(
extra_options: Optional[Dict[str, Any]] = None,
extra_headers: Optional[Dict[str, Any]] = None,
) -> None:
warnings.warn(
            "This class is deprecated and will be removed in 2.0.0. "
"Use `airflow.providers.apache.livy.hooks.livy.LivyHook` instead."
)
super().__init__(http_conn_id=livy_conn_id)
self.extra_headers = extra_headers or {}
self.extra_options = extra_options or {}
Expand Down
98 changes: 13 additions & 85 deletions astronomer/providers/apache/livy/operators/livy.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,23 @@
"""This module contains the Apache Livy operator async."""
from typing import Any, Dict
from __future__ import annotations

from airflow.exceptions import AirflowException
from airflow.providers.apache.livy.operators.livy import BatchState, LivyOperator
import warnings
from typing import Any

from astronomer.providers.apache.livy.triggers.livy import LivyTrigger
from astronomer.providers.utils.typing_compat import Context
from airflow.providers.apache.livy.operators.livy import LivyOperator


class LivyOperatorAsync(LivyOperator):
"""
This operator wraps the Apache Livy batch REST API, allowing to submit a Spark
application to the underlying cluster asynchronously.
:param file: path of the file containing the application to execute (required).
:param class_name: name of the application Java/Spark main class.
:param args: application command line arguments.
:param jars: jars to be used in this sessions.
:param py_files: python files to be used in this session.
:param files: files to be used in this session.
:param driver_memory: amount of memory to use for the driver process.
:param driver_cores: number of cores to use for the driver process.
:param executor_memory: amount of memory to use per executor process.
:param executor_cores: number of cores to use for each executor.
:param num_executors: number of executors to launch for this session.
:param archives: archives to be used in this session.
:param queue: name of the YARN queue to which the application is submitted.
:param name: name of this session.
:param conf: Spark configuration properties.
:param proxy_user: user to impersonate when running the job.
:param livy_conn_id: reference to a pre-defined Livy Connection.
:param polling_interval: time in seconds between polling for job completion. If poll_interval=0, in that case
return the batch_id and if polling_interval > 0, poll the livy job for termination in the polling interval
defined.
:param extra_options: Additional option can be passed when creating a request.
For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
:param extra_headers: A dictionary of headers passed to the HTTP request to livy.
:param retry_args: Arguments which define the retry behaviour.
See Tenacity documentation at https://github.com/jd/tenacity
    This class is deprecated.
    Use :class:`~airflow.providers.apache.livy.operators.livy.LivyOperator`
    with the `deferrable` param set to `True` instead.
"""

def execute(self, context: Context) -> Any:
"""
Airflow runs this method on the worker and defers using the trigger.
Submit the job and get the job_id using which we defer and poll in trigger
"""
self._batch_id = self.get_hook().post_batch(**self.spark_params)
self.log.info("Generated batch-id is %s", self._batch_id)

hook = self.get_hook()
state = hook.get_batch_state(self._batch_id, retry_args=self.retry_args)
self.log.debug("Batch with id %s is in state: %s", self._batch_id, state.value)
if state not in hook.TERMINAL_STATES:
self.defer(
timeout=self.execution_timeout,
trigger=LivyTrigger(
batch_id=self._batch_id,
spark_params=self.spark_params,
livy_conn_id=self._livy_conn_id,
polling_interval=self._polling_interval,
extra_options=self._extra_options,
extra_headers=self._extra_headers,
),
method_name="execute_complete",
)
else:
self.log.info("Batch with id %s terminated with state: %s", self._batch_id, state.value)
hook.dump_batch_logs(self._batch_id)
if state != BatchState.SUCCESS:
raise AirflowException(f"Batch {self._batch_id} did not succeed")

context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"])
return self._batch_id

def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
# dump the logs from livy to worker through triggerer.
if event.get("log_lines", None) is not None:
for log_line in event["log_lines"]:
self.log.info(log_line)

if event["status"] == "error":
raise AirflowException(event["response"])
self.log.info(
"%s completed with response %s",
self.task_id,
event["response"],
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
"This class is deprecated. "
"Use `airflow.providers.apache.livy.operators.livy.LivyOperator` "
"and set `deferrable` param to `True` instead.",
)
context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(event["batch_id"])["appId"])
return event["batch_id"]
super().__init__(*args, deferrable=True, **kwargs)
22 changes: 7 additions & 15 deletions astronomer/providers/apache/livy/triggers/livy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module contains the Apache Livy Trigger."""
import asyncio
import warnings
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union

from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand All @@ -9,21 +10,8 @@

class LivyTrigger(BaseTrigger):
"""
Check for the state of a previously submitted job with batch_id
:param batch_id: Batch job id
:param spark_params: Spark parameters; for example,
spark_params = {"file": "test/pi.py", "class_name": "org.apache.spark.examples.SparkPi",
"args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"],"jars": "command-runner.jar",
"driver_cores": 1, "executor_cores": 4,"num_executors": 1}
:param livy_conn_id: reference to a pre-defined Livy Connection.
:param polling_interval: time in seconds between polling for job completion. If poll_interval=0, in that case
return the batch_id and if polling_interval > 0, poll the livy job for termination in the polling interval
defined.
:param extra_options: A dictionary of options, where key is string and value
depends on the option that's being modified.
:param extra_headers: A dictionary of headers passed to the HTTP request to livy.
:param livy_hook_async: LivyHookAsync object
    This class is deprecated and will be removed in 2.0.0.
    Use :class:`~airflow.providers.apache.livy.triggers.livy.LivyTrigger` instead.
"""

def __init__(
Expand All @@ -36,6 +24,10 @@ def __init__(
extra_headers: Optional[Dict[str, Any]] = None,
livy_hook_async: Optional[LivyHookAsync] = None,
):
warnings.warn(
"This class is deprecated. "
"Use `airflow.providers.apache.livy.triggers.livy.LivyTrigger` instead.",
)
super().__init__()
self._batch_id = batch_id
self.spark_params = spark_params
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ apache.hive =
apache-airflow-providers-apache-hive>=6.1.5
impyla
apache.livy =
apache-airflow-providers-apache-livy
apache-airflow-providers-apache-livy>=3.7.1
paramiko
cncf.kubernetes =
apache-airflow-providers-cncf-kubernetes>=4
Expand Down Expand Up @@ -120,7 +120,7 @@ all =
aiobotocore>=2.1.1
apache-airflow-providers-amazon>=8.16.0
apache-airflow-providers-apache-hive>=6.1.5
apache-airflow-providers-apache-livy
apache-airflow-providers-apache-livy>=3.7.1
apache-airflow-providers-cncf-kubernetes>=4
apache-airflow-providers-databricks>=2.2.0
apache-airflow-providers-google>=8.1.0
Expand Down
166 changes: 5 additions & 161 deletions tests/apache/livy/operators/test_livy.py
Original file line number Diff line number Diff line change
@@ -1,174 +1,18 @@
from unittest.mock import MagicMock, patch
from __future__ import annotations

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.apache.livy.hooks.livy import BatchState
from airflow.utils import timezone
from airflow.providers.apache.livy.operators.livy import LivyOperator

from astronomer.providers.apache.livy.operators.livy import LivyOperatorAsync
from astronomer.providers.apache.livy.triggers.livy import LivyTrigger

DEFAULT_DATE = timezone.datetime(2017, 1, 1)
mock_livy_client = MagicMock()

BATCH_ID = 100
LOG_RESPONSE = {"total": 3, "log": ["first_line", "second_line", "third_line"]}


class TestLivyOperatorAsync:
@pytest.fixture()
@patch(
"astronomer.providers.apache.livy.hooks.livy.LivyHookAsync.dump_batch_logs",
return_value=None,
)
@patch("astronomer.providers.apache.livy.hooks.livy.LivyHookAsync.get_batch_state")
async def test_poll_for_termination(self, mock_livy, mock_dump_logs, dag):
state_list = 2 * [BatchState.RUNNING] + [BatchState.SUCCESS]

def side_effect(_, retry_args):
if state_list:
return state_list.pop(0)
# fail if does not stop right before
raise AssertionError()

mock_livy.side_effect = side_effect

task = LivyOperatorAsync(file="sparkapp", polling_interval=1, dag=dag, task_id="livy_example")
task._livy_hook = task.get_hook()
task.poll_for_termination(BATCH_ID)

mock_livy.assert_called_with(BATCH_ID, retry_args=None)
mock_dump_logs.assert_called_with(BATCH_ID)
assert mock_livy.call_count == 3

@pytest.mark.parametrize(
"mock_state",
(
BatchState.NOT_STARTED,
BatchState.STARTING,
BatchState.RUNNING,
BatchState.IDLE,
BatchState.SHUTTING_DOWN,
),
)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state")
def test_livy_operator_async(self, mock_get_batch_state, mock_post, mock_state, dag):
        mock_get_batch_state.return_value = mock_state
task = LivyOperatorAsync(
livy_conn_id="livyunittest",
file="sparkapp",
polling_interval=1,
dag=dag,
task_id="livy_example",
)

with pytest.raises(TaskDeferred) as exc:
task.execute({})

assert isinstance(exc.value.trigger, LivyTrigger), "Trigger is not a LivyTrigger"

@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
return_value=None,
)
@patch("astronomer.providers.apache.livy.operators.livy.LivyOperatorAsync.defer")
@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value={"appId": BATCH_ID}
)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
return_value=BatchState.SUCCESS,
)
def test_livy_operator_async_finish_before_deferred_success(
self, mock_get_batch_state, mock_post, mock_get, mock_defer, mock_dump_logs, dag
):
def test_init(self):
task = LivyOperatorAsync(
livy_conn_id="livyunittest",
file="sparkapp",
polling_interval=1,
dag=dag,
task_id="livy_example",
)
assert task.execute(context={"ti": MagicMock()}) == BATCH_ID
assert not mock_defer.called

@pytest.mark.parametrize(
"mock_state",
(
BatchState.ERROR,
BatchState.DEAD,
BatchState.KILLED,
),
)
@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
return_value=None,
)
@patch("astronomer.providers.apache.livy.operators.livy.LivyOperatorAsync.defer")
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state")
def test_livy_operator_async_finish_before_deferred_not_success(
self, mock_get_batch_state, mock_post, mock_defer, mock_dump_logs, mock_state, dag
):
mock_get_batch_state.return_value = mock_state

task = LivyOperatorAsync(
livy_conn_id="livyunittest",
file="sparkapp",
polling_interval=1,
dag=dag,
task_id="livy_example",
)
with pytest.raises(AirflowException):
task.execute({})
assert not mock_defer.called

@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value={"appId": BATCH_ID}
)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
def test_livy_operator_async_execute_complete_success(self, mock_post, mock_get, dag):
"""Asserts that a task is completed with success status."""
task = LivyOperatorAsync(
livy_conn_id="livyunittest",
file="sparkapp",
polling_interval=1,
dag=dag,
task_id="livy_example",
)
assert (
task.execute_complete(
context={"ti": MagicMock()},
event={
"status": "success",
"log_lines": None,
"batch_id": BATCH_ID,
"response": "mock success",
},
)
is BATCH_ID
)

@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
def test_livy_operator_async_execute_complete_error(self, mock_post, dag):
"""Asserts that a task is completed with success status."""

task = LivyOperatorAsync(
livy_conn_id="livyunittest",
file="sparkapp",
polling_interval=1,
dag=dag,
task_id="livy_example",
)
with pytest.raises(AirflowException):
task.execute_complete(
context={},
event={
"status": "error",
"log_lines": ["mock log"],
"batch_id": BATCH_ID,
"response": "mock error",
},
)
assert isinstance(task, LivyOperator)
assert task.deferrable is True

0 comments on commit 12cc4a7

Please sign in to comment.