Skip to content

Commit

Permalink
Update wait_for_worker checks to check instance is running and reacha…
Browse files Browse the repository at this point in the history
…ble (#5101)
  • Loading branch information
stacimc authored Oct 28, 2024
1 parent 0868741 commit f9b1769
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 10 deletions.
31 changes: 21 additions & 10 deletions catalog/dags/data_refresh/distributed_reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,19 @@ def create_worker(
return instances[0]["InstanceId"]


def _get_reachability_status(instance_data, status_type):
status_data = instance_data.get(status_type, {})
reachability_status = next(
(
status.get("Status")
for status in status_data.get("Details", [])
if status.get("Name") == "reachability"
),
None,
)
return reachability_status == "passed" and status_data.get("Status") == "ok"


@task.sensor(poke_interval=60, timeout=3600, mode="reschedule")
@setup_ec2_hook
def wait_for_worker(
Expand All @@ -194,24 +207,22 @@ def wait_for_worker(
raise AirflowSkipException("Skipping instance creation in local environment.")

result = ec2_hook.conn.describe_instance_status(InstanceIds=[instance_id])

logger.info(result)

instance_statuses = result.get("InstanceStatuses", [])
instance_statuses = result.get("InstanceStatuses")
instance_status = instance_statuses[0] if instance_statuses else {}

state = instance_status.get("InstanceState", {}).get("Name")
status = next(
(
status.get("Status")
for status in instance_status.get("InstanceStatus", {}).get("Details", [])
if status.get("Name") == "reachability"
),
None,
is_reachable = all(
[
_get_reachability_status(instance_status, status_type)
for status_type in ["InstanceStatus", "SystemStatus", "AttachedEbsStatus"]
]
)

return PokeReturnValue(
# Sensor completes only when the instance is running and has finished initializing
is_done=(state == "running" and status == "ok")
is_done=(state == "running" and is_reachable)
)


Expand Down
134 changes: 134 additions & 0 deletions catalog/tests/dags/data_refresh/test_distributed_reindex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import logging
from unittest import mock

import pytest
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook

from common.constants import PRODUCTION
from data_refresh.distributed_reindex import wait_for_worker


logger = logging.getLogger(__name__)


def _build_instance_status_response(
instance_state,
instance_status,
system_status,
attached_ebs_status,
instance_reachability_status,
system_reachability_status,
attached_ebs_reachability_status,
):
return {
"InstanceStatuses": [
{
"AvailabilityZone": "us-east-1a",
"InstanceId": "i-01234abc",
"InstanceState": {"Code": 16, "Name": instance_state},
"InstanceStatus": {
"Details": [
{"Name": "reachability", "Status": instance_reachability_status}
],
"Status": instance_status,
},
"SystemStatus": {
"Details": [
{"Name": "reachability", "Status": system_reachability_status}
],
"Status": system_status,
},
"AttachedEbsStatus": {
"Details": [
{
"Name": "reachability",
"Status": attached_ebs_reachability_status,
}
],
"Status": attached_ebs_status,
},
}
],
"ResponseMetadata": {
"RequestId": "123abc-def-456-789",
"HTTPStatusCode": 200,
"HTTPHeaders": {
"x-amzn-requestid": "123abc-def-456-789",
"cache-control": "no-cache, no-store",
"strict-transport-security": "max-age=31536000; includeSubDomains",
"content-type": "text/xml;charset=UTF-8",
"content-length": "819",
"date": "Fri, 25 Oct 2024 20:18:14 GMT",
"server": "AmazonEC2",
},
"RetryAttempts": 0,
},
}


@pytest.fixture
def mock_ec2_hook() -> EC2Hook:
    """Provide a MagicMock that is spec'd against `EC2Hook`."""
    hook_double = mock.MagicMock(spec=EC2Hook)
    return hook_double


@pytest.mark.parametrize(
    "instance_status_response, should_pass",
    [
        # No status entry returned for the instance at all. The sensor falls
        # back to an empty status dict in this case and must keep waiting.
        ({"InstanceStatuses": []}, False),
        # Instance not yet running; every check still initializing
        (
            _build_instance_status_response(
                instance_state="pending",
                instance_status="initializing",
                system_status="initializing",
                attached_ebs_status="initializing",
                instance_reachability_status="initializing",
                system_reachability_status="initializing",
                attached_ebs_reachability_status="initializing",
            ),
            False,
        ),
        # Running, but the instance status check is still initializing
        (
            _build_instance_status_response(
                instance_state="running",
                instance_status="initializing",
                system_status="ok",
                attached_ebs_status="ok",
                instance_reachability_status="initializing",
                system_reachability_status="passed",
                attached_ebs_reachability_status="passed",
            ),
            False,
        ),
        # Running, but the instance reachability check failed
        (
            _build_instance_status_response(
                instance_state="running",
                instance_status="insufficient_data",
                system_status="ok",
                attached_ebs_status="ok",
                instance_reachability_status="failed",
                system_reachability_status="passed",
                attached_ebs_reachability_status="passed",
            ),
            False,
        ),
        # Running, with several checks impaired or failed
        (
            _build_instance_status_response(
                instance_state="running",
                instance_status="impaired",
                system_status="ok",
                attached_ebs_status="impaired",
                instance_reachability_status="passed",
                system_reachability_status="failed",
                attached_ebs_reachability_status="failed",
            ),
            False,
        ),
        # Running, with only the attached EBS check impaired
        (
            _build_instance_status_response(
                instance_state="running",
                instance_status="ok",
                system_status="ok",
                attached_ebs_status="impaired",
                instance_reachability_status="passed",
                system_reachability_status="passed",
                attached_ebs_reachability_status="failed",
            ),
            False,
        ),
        # Only pass when the instance is running and all statuses are reachable
        (
            _build_instance_status_response(
                instance_state="running",
                instance_status="ok",
                system_status="ok",
                attached_ebs_status="ok",
                instance_reachability_status="passed",
                system_reachability_status="passed",
                attached_ebs_reachability_status="passed",
            ),
            True,
        ),
    ],
)
def test_wait_for_worker(instance_status_response, should_pass, mock_ec2_hook):
    """
    Check that `wait_for_worker` reports completion only for a healthy instance.

    The mocked EC2 hook returns `instance_status_response` from
    `describe_instance_status`; the sensor's `is_done` must equal
    `should_pass`.
    """
    mock_ec2_hook.conn.describe_instance_status.return_value = instance_status_response

    # Call the undecorated function directly, outside of any Airflow task run.
    poke_return_value = wait_for_worker.function(
        environment=PRODUCTION, instance_id="i-01234abc", ec2_hook=mock_ec2_hook
    )

    assert poke_return_value.is_done == should_pass

0 comments on commit f9b1769

Please sign in to comment.