-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Deprecate LivyOperatorAsync and proxy it to its Airflow OSS provider's counterpart closes: #1421
- Loading branch information
1 parent
3b136fb
commit 12cc4a7
Showing
5 changed files
with
34 additions
and
273 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,95 +1,23 @@ | ||
"""This module contains the Apache Livy operator async.""" | ||
from typing import Any, Dict | ||
from __future__ import annotations | ||
|
||
from airflow.exceptions import AirflowException | ||
from airflow.providers.apache.livy.operators.livy import BatchState, LivyOperator | ||
import warnings | ||
from typing import Any | ||
|
||
from astronomer.providers.apache.livy.triggers.livy import LivyTrigger | ||
from astronomer.providers.utils.typing_compat import Context | ||
from airflow.providers.apache.livy.operators.livy import LivyOperator | ||
|
||
|
||
class LivyOperatorAsync(LivyOperator):
    """
    This class is deprecated.

    Use :class:`~airflow.providers.apache.livy.operators.livy.LivyOperator`
    and set the ``deferrable`` param to ``True`` instead.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Use DeprecationWarning explicitly: warnings.warn() defaults to
        # UserWarning, which is the wrong category for a deprecation and is
        # handled differently by filters/tooling. stacklevel=2 attributes the
        # warning to the caller's line instead of this __init__.
        warnings.warn(
            "This class is deprecated. "
            "Use `airflow.providers.apache.livy.operators.livy.LivyOperator` "
            "and set `deferrable` param to `True` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Proxy all behavior to the OSS provider operator in deferrable mode.
        super().__init__(*args, deferrable=True, **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,174 +1,18 @@ | ||
from unittest.mock import MagicMock, patch | ||
from __future__ import annotations | ||
|
||
import pytest | ||
from airflow.exceptions import AirflowException, TaskDeferred | ||
from airflow.providers.apache.livy.hooks.livy import BatchState | ||
from airflow.utils import timezone | ||
from airflow.providers.apache.livy.operators.livy import LivyOperator | ||
|
||
from astronomer.providers.apache.livy.operators.livy import LivyOperatorAsync | ||
from astronomer.providers.apache.livy.triggers.livy import LivyTrigger | ||
|
||
DEFAULT_DATE = timezone.datetime(2017, 1, 1) | ||
mock_livy_client = MagicMock() | ||
|
||
BATCH_ID = 100 | ||
LOG_RESPONSE = {"total": 3, "log": ["first_line", "second_line", "third_line"]} | ||
|
||
|
||
class TestLivyOperatorAsync:
    def test_init(self):
        """Deprecated shim must be a LivyOperator configured with deferrable=True."""
        # NOTE(review): the old `dag` pytest fixture was removed with the async
        # tests, so the constructor must not reference `dag` anymore — the
        # leftover `dag=dag` kwarg would raise NameError inside test_init.
        task = LivyOperatorAsync(
            livy_conn_id="livyunittest",
            file="sparkapp",
            polling_interval=1,
            task_id="livy_example",
        )
        assert isinstance(task, LivyOperator)
        assert task.deferrable is True