Skip to content

Commit

Permalink
Merge pull request #6535 from MetRonnie/kill-prep
Browse files Browse the repository at this point in the history
Kill tasks during job prep
  • Loading branch information
hjoliver authored Feb 13, 2025
2 parents 1593ef3 + c28a565 commit 9fe7f51
Show file tree
Hide file tree
Showing 24 changed files with 507 additions and 215 deletions.
1 change: 1 addition & 0 deletions changes.d/6535.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure tasks can be killed while in the preparing state.
16 changes: 10 additions & 6 deletions cylc/flow/network/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ class WorkflowRuntimeServer:
"""
endpoints: Dict[str, object]
curve_auth: ThreadAuthenticator
"""The ZMQ authenticator."""
client_pub_key_dir: str
"""Client public key directory, used by the ZMQ authenticator."""

OPERATE_SLEEP_INTERVAL = 0.2
STOP_SLEEP_INTERVAL = 0.2
Expand All @@ -136,8 +140,6 @@ def __init__(self, schd):
self.publisher = None
self.loop = None
self.thread = None
self.curve_auth = None
self.client_pub_key_dir = None

self.schd: 'Scheduler' = schd
self.resolvers = Resolvers(
Expand Down Expand Up @@ -184,10 +186,7 @@ def start(self, barrier):
self.client_pub_key_dir = client_pub_keyinfo.key_path

# Initial load for the localhost key.
self.curve_auth.configure_curve(
domain='*',
location=(self.client_pub_key_dir)
)
self.configure_curve()

min_, max_ = glbl_cfg().get(['scheduler', 'run hosts', 'ports'])
self.replier = WorkflowReplier(self, context=self.zmq_context)
Expand All @@ -207,6 +206,11 @@ def start(self, barrier):

self.operate()

def configure_curve(self) -> None:
    """(Re)configure CURVE authentication from the client key directory.

    Passes ``self.client_pub_key_dir`` to the ZMQ authenticator
    (``self.curve_auth``) as the key location for the ``'*'`` domain.
    """
    key_dir = self.client_pub_key_dir
    authenticator = self.curve_auth
    authenticator.configure_curve(domain='*', location=key_dir)

async def stop(self, reason: Union[BaseException, str]) -> None:
"""Stop the TCP servers, and clean up authentication.
Expand Down
60 changes: 34 additions & 26 deletions cylc/flow/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections import deque
from contextlib import suppress
import itertools
import logging
import os
from pathlib import Path
from queue import (
Expand Down Expand Up @@ -82,6 +83,7 @@
FLOW_NONE,
FlowMgr,
repr_flow_nums,
stringify_flow_nums,
)
from cylc.flow.host_select import (
HostSelectException,
Expand Down Expand Up @@ -440,7 +442,8 @@ async def initialise(self):
self.workflow_db_mgr,
self.task_events_mgr,
self.data_store_mgr,
self.bad_hosts
self.bad_hosts,
self.server,
)

self.profiler = Profiler(self, self.options.profile_mode)
Expand Down Expand Up @@ -910,9 +913,7 @@ def restart_remote_init(self):
if install_target == get_localhost_install_target():
continue
# set off remote init
self.task_job_mgr.task_remote_mgr.remote_init(
platform, self.server.curve_auth,
self.server.client_pub_key_dir)
self.task_job_mgr.task_remote_mgr.remote_init(platform)
# Remote init/file-install is done via process pool
self.proc_pool.process()
# add platform to map (to be picked up on main loop)
Expand Down Expand Up @@ -1078,18 +1079,21 @@ def kill_tasks(
to_kill: List[TaskProxy] = []
unkillable: List[TaskProxy] = []
for itask in itasks:
if itask.state(*TASK_STATUSES_ACTIVE):
if itask.state_reset(is_held=True):
self.data_store_mgr.delta_task_state(itask)
if not itask.state(TASK_STATUS_PREPARING, *TASK_STATUSES_ACTIVE):
unkillable.append(itask)
continue
if itask.state_reset(is_held=True):
self.data_store_mgr.delta_task_state(itask)
if itask.state(TASK_STATUS_PREPARING):
self.task_job_mgr.kill_prep_task(itask)
else:
to_kill.append(itask)
if jobless:
# Directly set failed in sim mode:
self.task_events_mgr.process_message(
itask, 'CRITICAL', TASK_STATUS_FAILED,
flag=self.task_events_mgr.FLAG_RECEIVED
)
else:
unkillable.append(itask)
if warn and unkillable:
LOG.warning(
"Tasks not killable: "
Expand Down Expand Up @@ -1250,6 +1254,7 @@ def get_contact_data(self) -> Dict[str, str]:
"""
fields = workflow_files.ContactFileFields
proc = psutil.Process()
platform = get_platform()
# fmt: off
return {
fields.API:
Expand All @@ -1275,11 +1280,11 @@ def get_contact_data(self) -> Dict[str, str]:
fields.VERSION:
CYLC_VERSION,
fields.SCHEDULER_SSH_COMMAND:
str(get_platform()['ssh command']),
str(platform['ssh command']),
fields.SCHEDULER_CYLC_PATH:
str(get_platform()['cylc path']),
str(platform['cylc path']),
fields.SCHEDULER_USE_LOGIN_SHELL:
str(get_platform()['use login shell'])
str(platform['use login shell'])
}
# fmt: on

Expand Down Expand Up @@ -1531,29 +1536,32 @@ def start_job_submission(self, itasks: 'Iterable[TaskProxy]') -> bool:
self.task_job_mgr.task_remote_mgr.rsync_includes = (
self.config.get_validated_rsync_includes())

log = LOG.debug
submitted = self.submit_task_jobs(itasks)
if not submitted:
return False

log_lvl = logging.DEBUG
if self.options.reftest or self.options.genref:
log = LOG.info
log_lvl = logging.INFO

for itask in self.task_job_mgr.submit_task_jobs(
self.workflow,
itasks,
self.server.curve_auth,
self.server.client_pub_key_dir,
run_mode=self.get_run_mode()
):
if itask.flow_nums:
flow = ','.join(str(i) for i in itask.flow_nums)
else:
flow = FLOW_NONE
log(
for itask in submitted:
flow = stringify_flow_nums(itask.flow_nums) or FLOW_NONE
LOG.log(
log_lvl,
f"{itask.identity} -triggered off "
f"{itask.state.get_resolved_dependencies()} in flow {flow}"
)

# one or more tasks were passed through the submission pipeline
return True

def submit_task_jobs(
    self, itasks: 'Iterable[TaskProxy]'
) -> 'List[TaskProxy]':
    """Submit jobs for the given tasks via the task job manager.

    Args:
        itasks: The task proxies to submit jobs for.

    Returns:
        The tasks that attempted submission, as returned by the task job
        manager.
    """
    # Note: keep this as a simple wrapper around the task job manager's
    # method of the same name.
    submit = self.task_job_mgr.submit_task_jobs
    return submit(itasks, self.get_run_mode())

def process_workflow_db_queue(self):
"""Update workflow DB."""
self.workflow_db_mgr.process_queued_ops()
Expand Down
4 changes: 4 additions & 0 deletions cylc/flow/subprocctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from shlex import quote
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union


from cylc.flow.wallclock import get_current_time_string

if TYPE_CHECKING:
Expand Down Expand Up @@ -137,6 +138,9 @@ def __str__(self):
'mesg': mesg}
return ret.rstrip()

def __repr__(self) -> str:
    """Return a compact debug form: ``<ClassName cmd_key>``."""
    cls_name = type(self).__name__
    return "<{} {}>".format(cls_name, self.cmd_key)

Check warning on line 142 in cylc/flow/subprocctx.py

View check run for this annotation

Codecov / codecov/patch

cylc/flow/subprocctx.py#L142

Added line #L142 was not covered by tests


class SubFuncContext(SubProcContext):
"""Represent the context of a Python function to run as a subprocess.
Expand Down
Loading

0 comments on commit 9fe7f51

Please sign in to comment.