Skip to content

Commit

Permalink
Cancel stuck requests on dask scheduler (#136)
Browse files Browse the repository at this point in the history
* fix: update qos when a request is running on a killed worker

* cancel the job on the dask scheduler when the broker does not have the future

* qa
  • Loading branch information
francesconazzaro authored Oct 10, 2024
1 parent 2906793 commit 39abb14
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 10 deletions.
2 changes: 2 additions & 0 deletions cads_broker/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class BrokerConfig(pydantic_settings.BaseSettings):
broker_requeue_limit: int = 3
broker_max_internal_scheduler_tasks: int = 500
broker_max_accepted_requests: int = 2000
broker_cancel_stuck_requests_cache_ttl: int = 60
broker_stuck_requests_limit_hours: int = 1


class SqlalchemySettings(pydantic_settings.BaseSettings):
Expand Down
16 changes: 15 additions & 1 deletion cads_broker/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

logger: structlog.stdlib.BoundLogger = structlog.get_logger(__name__)


status_enum = sa.Enum(
"accepted", "running", "failed", "successful", "dismissed", "deleted", name="status"
)
Expand Down Expand Up @@ -476,6 +475,21 @@ def get_users_queue_from_processing_time(
return queueing_user_costs | running_user_costs


def get_stuck_requests(session: sa.orm.Session, hours: int = 1) -> list[str]:
    """Return the UIDs of "stuck" running requests.

    A request is considered stuck when its status is "running", it started
    more than *hours* ago (by database clock), and it has no associated
    event row — i.e. no worker has ever reported on it.

    Parameters
    ----------
    session: SQLAlchemy session used for the read-only query.
    hours: age threshold, in hours, before a running request counts as stuck.

    Returns
    -------
    List of matching request UIDs (possibly empty).
    """
    import datetime  # local import: keeps this fix self-contained

    # Bind the threshold as a timedelta parameter instead of interpolating
    # it into raw SQL text: no string-built SQL, and the driver adapts the
    # timedelta to a native INTERVAL bind.
    threshold = datetime.timedelta(hours=hours)
    query = (
        sa.select(SystemRequest.request_uid)
        .outerjoin(Events, SystemRequest.request_uid == Events.request_uid)
        .where(
            SystemRequest.status == "running",
            SystemRequest.started_at < sa.func.now() - threshold,
        )
        # LEFT OUTER JOIN + IS NULL keeps only requests with no events.
        .where(Events.event_id.is_(None))
    )
    # .all() yields a Sequence; coerce to list to match the annotation.
    return list(session.execute(query).scalars().all())


def delete_request_qos_status(
request_uid: str,
rules: list,
Expand Down
58 changes: 49 additions & 9 deletions cads_broker/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,37 @@ def get_tasks_on_scheduler(dask_scheduler: distributed.Scheduler) -> dict[str, A
return client.run_on_scheduler(get_tasks_on_scheduler)


def cancel_jobs_on_scheduler(client: distributed.Client, job_ids: list[str]) -> None:
    """Cancel the given jobs directly on the dask scheduler.

    The inner callback runs on the scheduler process itself and only drops
    the tasks there (no worker-side cleanup is performed).
    See https://stackoverflow.com/questions/49203128/how-do-i-stop-a-running-task-in-dask.
    """

    def _cancel(dask_scheduler: distributed.Scheduler, job_ids: list[str]) -> None:
        known_tasks = dask_scheduler.tasks
        for uid in job_ids:
            # Skip identifiers the scheduler has never heard of.
            if uid not in known_tasks:
                continue
            dask_scheduler.transitions({uid: "cancelled"}, stimulus_id="manual-cancel")

    return client.run_on_scheduler(_cancel, job_ids=job_ids)


@cachetools.cached(  # type: ignore
    cache=cachetools.TTLCache(
        maxsize=1024, ttl=CONFIG.broker_cancel_stuck_requests_cache_ttl
    ),
    info=True,
)
def cancel_stuck_requests(client: distributed.Client, session: sa.orm.Session) -> None:
    """Fetch stuck requests from the database and cancel them on the scheduler.

    The TTL cache throttles the check: for a given (client, session) pair it
    runs at most once every ``broker_cancel_stuck_requests_cache_ttl`` seconds.
    """
    job_ids = db.get_stuck_requests(
        session=session, hours=CONFIG.broker_stuck_requests_limit_hours
    )
    cancel_jobs_on_scheduler(client, job_ids=job_ids)


class Scheduler:
"""A simple scheduler to store the tasks to update the qos_rules in the database.
Expand Down Expand Up @@ -316,8 +347,12 @@ def set_request_error_status(
< CONFIG.broker_requeue_limit
):
logger.info("worker killed: re-queueing", job_id=request_uid)
db.requeue_request(request=request, session=session)
self.queue.add(request_uid, request)
queued_request = db.requeue_request(request=request, session=session)
if queued_request:
self.queue.add(request_uid, request)
self.qos.notify_end_of_request(
request, session, scheduler=self.internal_scheduler
)
else:
request = db.set_request_status(
request_uid,
Expand Down Expand Up @@ -375,9 +410,13 @@ def sync_database(self, session: sa.orm.Session) -> None:
dismissed_requests = db.get_dismissed_requests(
session, limit=CONFIG.broker_max_accepted_requests
)
for i, request in enumerate(dismissed_requests):
for request in dismissed_requests:
if future := self.futures.pop(request.request_uid, None):
future.cancel()
else:
# if the request is not in the futures, it means that the request has been lost by the broker
# try to cancel the job directly on the scheduler
cancel_jobs_on_scheduler(self.client, job_ids=[request.request_uid])
session = self.manage_dismissed_request(request, session)
session.commit()

Expand Down Expand Up @@ -606,12 +645,12 @@ def processing_time_priority_algorithm(
interval_stop = datetime.datetime.now()
# temporary solution to prioritize high priority user
users_queue = {
"27888ffa-0973-4794-9b3c-9efb6767f66f": 0, # wekeo
"d67a13db-86cc-439d-823d-6517003de29f": 0, # CDS Apps user
"365ac1da-090e-4b85-9088-30c676bc5251": 0, # Gionata
"74c6f9a1-8efe-4a6c-b06b-9f8ddcab188d": 0, # User Support
"4d92cc89-d586-4731-8553-07df5dae1886": 0, # Luke Jones
"8d8ee054-6a09-4da8-a5be-d5dff52bbc5f": 0, # Petrut
"27888ffa-0973-4794-9b3c-9efb6767f66f": 0, # wekeo
"d67a13db-86cc-439d-823d-6517003de29f": 0, # CDS Apps user
"365ac1da-090e-4b85-9088-30c676bc5251": 0, # Gionata
"74c6f9a1-8efe-4a6c-b06b-9f8ddcab188d": 0, # User Support
"4d92cc89-d586-4731-8553-07df5dae1886": 0, # Luke Jones
"8d8ee054-6a09-4da8-a5be-d5dff52bbc5f": 0, # Petrut
} | db.get_users_queue_from_processing_time(
interval_stop=interval_stop,
session=session_write,
Expand Down Expand Up @@ -736,6 +775,7 @@ def run(self) -> None:
self.queue.values(), session_write
)

cancel_stuck_requests(client=self.client, session=session_read)
running_requests = len(db.get_running_requests(session=session_read))
queue_length = self.queue.len()
available_workers = self.number_of_workers - running_requests
Expand Down
45 changes: 45 additions & 0 deletions tests/test_02_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,51 @@ def test_count_requests(session_obj: sa.orm.sessionmaker) -> None:
assert 3 == db.count_requests(session=session, status=["accepted", "running"])


def test_get_stuck_running_requests(session_obj: sa.orm.sessionmaker) -> None:
    """Only running requests older than the cutoff and without events are stuck."""
    adaptor_properties = mock_config()
    now = datetime.datetime.now()
    # Wrong status: must be ignored regardless of age.
    accepted = mock_system_request(
        status="accepted",
        adaptor_properties_hash=adaptor_properties.hash,
    )
    # Running for 2 hours, no events: stuck.
    stuck_old = mock_system_request(
        status="running",
        started_at=now - datetime.timedelta(hours=2),
        adaptor_properties_hash=adaptor_properties.hash,
    )
    # Running just past the 1-hour default cutoff, no events: stuck.
    stuck_recent = mock_system_request(
        status="running",
        started_at=now - datetime.timedelta(hours=1, minutes=10),
        adaptor_properties_hash=adaptor_properties.hash,
    )
    # Very old running request, but a worker event exists: not stuck.
    with_event = mock_system_request(
        status="running",
        started_at=now - datetime.timedelta(hours=5, minutes=10),
        adaptor_properties_hash=adaptor_properties.hash,
    )
    event = mock_event(
        request_uid=with_event.request_uid,
        event_type="worker-name",
        message="worker-0",
        timestamp=now - datetime.timedelta(minutes=5),
    )
    expected_uids = {stuck_old.request_uid, stuck_recent.request_uid}

    with session_obj() as session:
        session.add_all(
            [adaptor_properties, accepted, stuck_old, stuck_recent, with_event, event]
        )
        session.commit()
        stuck = db.get_stuck_requests(session=session)

    assert len(stuck) == 2
    assert set(stuck) == expected_uids


def test_add_qos_rule(session_obj: sa.orm.sessionmaker) -> None:
rule = MockRule("rule_name", "conclusion", "info", "condition")
with session_obj() as session:
Expand Down

0 comments on commit 39abb14

Please sign in to comment.