Skip to content

Commit 614be87

Browse files
authored
Added retry to ECS Operator (#14263)
* Added retry to ECS Operator * ... * Remove airflow/www/yarn-error.log * Update decorator to not accept any params * ... * ... * ... * lint * Add predicate argument in retry decorator * Add wraps and fixed test * ... * Remove unnecessary retry_if_permissible_error and fix lint errors * Static check fixes * Fix TestECSOperator.test_execute_with_failures
1 parent a7f2cc2 commit 614be87

File tree

5 files changed

+172
-4
lines changed

5 files changed

+172
-4
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
# Note: Any AirflowException raised is expected to cause the TaskInstance
20+
# to be marked in an ERROR state
21+
22+
23+
class ECSOperatorError(Exception):
    """
    Raise when ECS cannot handle the request.

    :param failures: failure entries to keep on the exception as ``failures``
    :param message: text handed to ``Exception`` and also kept as ``message``
    """

    def __init__(self, failures: list, message: str):
        # Initialise the base exception first, then attach the extra context.
        super().__init__(message)
        self.message = message
        self.failures = failures

airflow/providers/amazon/aws/hooks/base_aws.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@
2727
import configparser
2828
import datetime
2929
import logging
30-
from typing import Any, Dict, Optional, Tuple, Union
30+
from functools import wraps
31+
from typing import Any, Callable, Dict, Optional, Tuple, Union
3132

3233
import boto3
3334
import botocore
3435
import botocore.session
36+
import tenacity
3537
from botocore.config import Config
3638
from botocore.credentials import ReadOnlyCredentials
3739

@@ -488,6 +490,37 @@ def expand_role(self, role: str) -> str:
488490
else:
489491
return self.get_client_type("iam").get_role(RoleName=role)["Role"]["Arn"]
490492

493+
@staticmethod
def retry(should_retry: Callable[[Exception], bool]):
    """
    Build a decorator that repeats a call when it fails with a retryable exception.

    The typical use is repeating a request that was rejected because a temporary
    quota limit was exceeded.

    The decorated callable must take its owning object as the first positional
    argument. Retrying is configured per object through an optional ``retry_args``
    dict attribute. When the attribute is missing or ``None``, the callable is
    invoked exactly once with no retry handling. Recognised keys:

    * ``multiplier`` (default 1), ``min`` (default 1), ``max`` (default 1) -
      forwarded to ``tenacity.wait_exponential``
    * ``stop_after_delay`` (default 10) - forwarded to ``tenacity.stop_after_delay``

    :param should_retry: predicate that receives the raised exception and returns
        True when the call should be attempted again
    :return: a decorator to apply to the callable
    """

    def retry_decorator(fun: Callable):
        @wraps(fun)
        def decorator_f(self, *args, **kwargs):
            retry_args = getattr(self, 'retry_args', None)
            if retry_args is None:
                # Retrying is not configured: call straight through, forwarding
                # every argument the caller supplied.
                return fun(self, *args, **kwargs)

            multiplier = retry_args.get('multiplier', 1)
            min_limit = retry_args.get('min', 1)
            max_limit = retry_args.get('max', 1)
            stop_after_delay = retry_args.get('stop_after_delay', 10)

            # Only attach logging hooks when the object actually carries a logger.
            log = getattr(self, 'log', None)
            before_logger = tenacity.before_log(log, logging.DEBUG) if log else None
            # Use the dedicated "after" hook so the message describes a finished
            # attempt rather than announcing a start.
            after_logger = tenacity.after_log(log, logging.DEBUG) if log else None

            default_kwargs = {
                'wait': tenacity.wait_exponential(multiplier=multiplier, max=max_limit, min=min_limit),
                'retry': tenacity.retry_if_exception(should_retry),
                'stop': tenacity.stop_after_delay(stop_after_delay),
                'before': before_logger,
                'after': after_logger,
            }
            return tenacity.retry(**default_kwargs)(fun)(self, *args, **kwargs)

        return decorator_f

    return retry_decorator
523+
491524

492525
def _parse_s3_config(
493526
config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None

airflow/providers/amazon/aws/operators/ecs.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,24 @@
2525

2626
from airflow.exceptions import AirflowException
2727
from airflow.models import BaseOperator
28+
from airflow.providers.amazon.aws.exceptions import ECSOperatorError
2829
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
2930
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
3031
from airflow.typing_compat import Protocol, runtime_checkable
3132
from airflow.utils.decorators import apply_defaults
3233

3334

35+
def should_retry(exception: Exception) -> bool:
    """
    Check if exception is related to ECS resource quota (CPU, MEM).

    Only ``ECSOperatorError`` instances are inspected; any other exception is
    never retried. A failure entry that is not a dict, or that carries no
    ``reason``, is treated as non-retryable instead of raising from inside the
    retry predicate (which would hide the original error).

    :param exception: the exception raised by the decorated call
    :return: True when at least one failure reason mentions a CPU or memory quota
    """
    if not isinstance(exception, ECSOperatorError):
        return False

    quota_reasons = ('RESOURCE:MEMORY', 'RESOURCE:CPU')
    for failure in exception.failures or ():
        reason = failure.get('reason') if isinstance(failure, dict) else None
        if isinstance(reason, str) and any(quota_reason in reason for quota_reason in quota_reasons):
            return True
    return False
44+
45+
3446
@runtime_checkable
3547
class ECSProtocol(Protocol):
3648
"""
@@ -125,6 +137,8 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
125137
:param reattach: If set to True, will check if a task from the same family is already running.
126138
If so, the operator will attach to it instead of starting a new task.
127139
:type reattach: bool
140+
:param quota_retry: Config if and how to retry _start_task() for transient errors.
141+
:type quota_retry: dict
128142
"""
129143

130144
ui_color = '#f0ede4'
@@ -150,6 +164,7 @@ def __init__(
150164
awslogs_region: Optional[str] = None,
151165
awslogs_stream_prefix: Optional[str] = None,
152166
propagate_tags: Optional[str] = None,
167+
quota_retry: Optional[dict] = None,
153168
reattach: bool = False,
154169
**kwargs,
155170
):
@@ -180,6 +195,7 @@ def __init__(
180195
self.hook: Optional[AwsBaseHook] = None
181196
self.client: Optional[ECSProtocol] = None
182197
self.arn: Optional[str] = None
198+
self.retry_args = quota_retry
183199

184200
def execute(self, context):
185201
self.log.info(
@@ -206,6 +222,7 @@ def execute(self, context):
206222

207223
return None
208224

225+
@AwsBaseHook.retry(should_retry)
209226
def _start_task(self):
210227
run_opts = {
211228
'cluster': self.cluster,
@@ -235,7 +252,7 @@ def _start_task(self):
235252

236253
failures = response['failures']
237254
if len(failures) > 0:
238-
raise AirflowException(response)
255+
raise ECSOperatorError(failures, response)
239256
self.log.info('ECS Task started: %s', response)
240257

241258
self.arn = response['tasks'][0]['taskArn']

tests/providers/amazon/aws/hooks/test_base_aws.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from unittest import mock
2222

2323
import boto3
24+
import pytest
2425

2526
from airflow.models import Connection
2627
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -266,3 +267,82 @@ def test_use_default_boto3_behaviour_without_conn_id(self):
266267
hook = AwsBaseHook(aws_conn_id=conn_id, client_type='s3')
267268
# should cause no exception
268269
hook.get_client_type('s3')
270+
271+
272+
class ThrowErrorUntilCount:
    """Callable that fails a fixed number of times in a row and then succeeds."""

    def __init__(self, count, quota_retry, **kwargs):
        self.count = count
        self.counter = 0
        self.kwargs = kwargs
        self.retry_args = quota_retry
        self.log = None

    def __call__(self):
        """
        Raise an Exception until the call counter reaches the count threshold.

        Once the threshold has been reached, return True.
        """
        if self.counter >= self.count:
            return True
        self.counter += 1
        raise Exception()
291+
292+
293+
def _always_true_predicate(e: Exception): # pylint: disable=unused-argument
294+
return True
295+
296+
297+
@AwsBaseHook.retry(_always_true_predicate)
def _retryable_test(thing):
    """Invoke ``thing`` under a retry decorator whose predicate always allows a retry."""
    outcome = thing()
    return outcome
300+
301+
302+
def _always_false_predicate(e: Exception): # pylint: disable=unused-argument
303+
return False
304+
305+
306+
@AwsBaseHook.retry(_always_false_predicate)
def _non_retryable_test(thing):
    """Invoke ``thing`` under a retry decorator whose predicate never allows a retry."""
    outcome = thing()
    return outcome
309+
310+
311+
class TestRetryDecorator(unittest.TestCase):  # pylint: disable=invalid-name
    """Exercise ``AwsBaseHook.retry`` through the module-level helper functions."""

    def test_do_nothing_on_non_exception(self):
        # A plain lambda has no ``retry_args`` attribute, so it is called straight through.
        result = _retryable_test(lambda: 42)
        # Compare the value; ``assert result, 42`` would only use 42 as the failure message.
        assert result == 42

    def test_retry_on_exception(self):
        quota_retry = {
            'stop_after_delay': 2,
            'multiplier': 1,
            'min': 1,
            'max': 10,
        }
        custom_fn = ThrowErrorUntilCount(
            count=2,
            quota_retry=quota_retry,
        )
        result = _retryable_test(custom_fn)
        # Two failing attempts were retried before the call finally succeeded.
        assert custom_fn.counter == 2
        assert result

    def test_no_retry_on_exception(self):
        quota_retry = {
            'stop_after_delay': 2,
            'multiplier': 1,
            'min': 1,
            'max': 10,
        }
        custom_fn = ThrowErrorUntilCount(
            count=2,
            quota_retry=quota_retry,
        )
        with pytest.raises(Exception):
            _non_retryable_test(custom_fn)
        # The predicate rejects every exception, so only the first attempt may have run.
        assert custom_fn.counter == 1

    def test_raise_exception_when_no_retry_args(self):
        custom_fn = ThrowErrorUntilCount(count=2, quota_retry=None)
        with pytest.raises(Exception):
            _retryable_test(custom_fn)
        # With ``retry_args`` unset the callable is invoked exactly once.
        assert custom_fn.counter == 1

tests/providers/amazon/aws/operators/test_ecs.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from parameterized import parameterized
2727

2828
from airflow.exceptions import AirflowException
29-
from airflow.providers.amazon.aws.operators.ecs import ECSOperator
29+
from airflow.providers.amazon.aws.exceptions import ECSOperatorError
30+
from airflow.providers.amazon.aws.operators.ecs import ECSOperator, should_retry
3031

3132
# fmt: off
3233
RESPONSE_WITHOUT_FAILURES = {
@@ -145,7 +146,7 @@ def test_execute_with_failures(self):
145146
resp_failures['failures'].append('dummy error')
146147
client_mock.run_task.return_value = resp_failures
147148

148-
with pytest.raises(AirflowException):
149+
with pytest.raises(ECSOperatorError):
149150
self.ecs.execute(None)
150151

151152
self.aws_hook_mock.return_value.get_conn.assert_called_once()
@@ -326,3 +327,11 @@ def test_execute_xcom_with_no_log(self, mock_cloudwatch_log_message):
326327
def test_execute_xcom_disabled(self, mock_cloudwatch_log_message):
327328
self.ecs.do_xcom_push = False
328329
assert self.ecs.execute(None) is None
330+
331+
332+
class TestShouldRetry(unittest.TestCase):
    """Check the ``should_retry`` predicate against ECS failure reasons."""

    def test_return_true_on_valid_reason(self):
        # A memory-quota failure reason must be reported as retryable.
        error = ECSOperatorError([{'reason': 'RESOURCE:MEMORY'}], 'Foo')
        verdict = should_retry(error)
        self.assertTrue(verdict)

    def test_return_false_on_invalid_reason(self):
        # A reason unrelated to CPU/memory quota must not be retried.
        error = ECSOperatorError([{'reason': 'CLUSTER_NOT_FOUND'}], 'Foo')
        verdict = should_retry(error)
        self.assertFalse(verdict)

0 commit comments

Comments
 (0)