From f733b3a6fa26160800fdf0a3dd18e896e848280a Mon Sep 17 00:00:00 2001
From: Hank Doupe
Date: Sat, 3 Oct 2020 11:22:37 -0400
Subject: [PATCH] Add back timeout for sims

Pass a --timeout flag (1.25x the project's expected_task_time) to the
"csw job" container command and enforce it in async_task_wrapper by
running the handler in the default executor under asyncio.wait_for.
The API task path passes timeout=None, so it still calls the handler
directly.
---
Review notes (text here is ignored by git am):

- clients/job.py: the "--timeout" value is now str(timeout). The other
  command entries are strings, and timeout is a float because it is
  multiplied by 1.25.
- executors/job.py: --timeout is parsed with type=float. type=int
  rejects float-valued text such as "75.0".
- executors/task_wrapper.py: run_in_executor only forwards positional
  arguments, so the keyword arguments are bound with functools.partial
  before the call is handed to the executor.
- NOTE(review): asyncio.wait_for stops waiting on timeout, but the
  function already running in the executor thread is not interrupted.
  Confirm the job process actually exits once the timeout fires.
- NOTE(review): indentation and blank context lines were reconstructed
  from the hunk line counts; run `git apply --check` before applying.

 workers/cs_workers/models/clients/job.py            | 12 +++++++++++-
 workers/cs_workers/models/executors/api_task.py     |  4 +++-
 workers/cs_workers/models/executors/job.py          |  5 ++++-
 workers/cs_workers/models/executors/task_wrapper.py | 11 +++++++++--
 4 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/workers/cs_workers/models/clients/job.py b/workers/cs_workers/models/clients/job.py
index 38741f08..cbc2916c 100644
--- a/workers/cs_workers/models/clients/job.py
+++ b/workers/cs_workers/models/clients/job.py
@@ -99,6 +99,7 @@ def configure(self, owner, title, tag, job_id=None):
             job_id = str(job_id)
 
         config = self.model_config.projects()[f"{owner}/{title}"]
+        timeout = config["expected_task_time"] * 1.25
 
         safeowner = clean(owner)
         safetitle = clean(title)
@@ -106,7 +107,16 @@ def configure(self, owner, title, tag, job_id=None):
         container = kclient.V1Container(
             name=job_id,
             image=f"{self.cr}/{self.project}/{safeowner}_{safetitle}_tasks:{tag}",
-            command=["csw", "job", "--job-id", job_id, "--route-name", "sim"],
+            command=[
+                "csw",
+                "job",
+                "--job-id",
+                job_id,
+                "--route-name",
+                "sim",
+                "--timeout",
+                str(timeout),
+            ],
             env=self.env(owner, title, config),
             resources=kclient.V1ResourceRequirements(**config["resources"]),
         )
diff --git a/workers/cs_workers/models/executors/api_task.py b/workers/cs_workers/models/executors/api_task.py
index 0ef4c8ff..e26813ba 100644
--- a/workers/cs_workers/models/executors/api_task.py
+++ b/workers/cs_workers/models/executors/api_task.py
@@ -46,7 +46,9 @@ async def post(self):
         if task_id is None:
             task_id = str(uuid.uuid4())
         task_kwargs = payload.get("task_kwargs")
-        async_task = async_task_wrapper(task_id, task_name, handler, task_kwargs)
+        async_task = async_task_wrapper(
+            task_id, task_name, handler, timeout=None, task_kwargs=task_kwargs
+        )
         asyncio.create_task(async_task)
         self.set_status(200)
         self.write({"status": "PENDING", "task_id": task_id})
diff --git a/workers/cs_workers/models/executors/job.py b/workers/cs_workers/models/executors/job.py
index a0ad740e..f1821c54 100644
--- a/workers/cs_workers/models/executors/job.py
+++ b/workers/cs_workers/models/executors/job.py
@@ -39,7 +39,9 @@ def sim_handler(task_id, meta_param_dict, adjustment):
 
 def main(args: argparse.Namespace):
     asyncio.run(
-        async_task_wrapper(args.job_id, args.route_name, routes[args.route_name])
+        async_task_wrapper(
+            args.job_id, args.route_name, routes[args.route_name], timeout=args.timeout
+        )
     )
 
 
@@ -47,4 +49,5 @@ def cli(subparsers: argparse._SubParsersAction):
     parser = subparsers.add_parser("job", description="CLI for C/S jobs.")
     parser.add_argument("--job-id", "-t", required=True)
     parser.add_argument("--route-name", "-r", required=True)
+    parser.add_argument("--timeout", required=False, type=float)
     parser.set_defaults(func=main)
diff --git a/workers/cs_workers/models/executors/task_wrapper.py b/workers/cs_workers/models/executors/task_wrapper.py
index 8dccae06..6373feca 100644
--- a/workers/cs_workers/models/executors/task_wrapper.py
+++ b/workers/cs_workers/models/executors/task_wrapper.py
@@ -1,3 +1,4 @@
+import asyncio
 import functools
 import json
 import os
@@ -51,7 +52,7 @@ async def sync_task_wrapper(task_id, task_name, func, task_kwargs=None):
     return res
 
 
-async def async_task_wrapper(task_id, task_name, func, task_kwargs=None):
+async def async_task_wrapper(task_id, task_name, func, timeout=None, task_kwargs=None):
     print("async task", task_id, func, task_kwargs)
     start = time.time()
     traceback_str = None
@@ -66,7 +67,13 @@ async def async_task_wrapper(task_id, task_name, func, task_kwargs=None):
                 task_kwargs = rclient.get(_task_id)
             if task_kwargs is not None:
                 task_kwargs = json.loads(task_kwargs.decode())
-        outputs = func(task_id, **(task_kwargs or {}))
+
+        if timeout:
+            call = functools.partial(func, task_id, **(task_kwargs or {}))
+            fut = asyncio.get_running_loop().run_in_executor(None, call)
+            outputs = await asyncio.wait_for(fut, timeout=timeout)
+        else:
+            outputs = func(task_id, **(task_kwargs or {}))
         res.update(
             {
                 "model_version": functions.get_version(),