Skip to content

Commit

Permalink
Implement signal cancellation (#1154)
Browse files Browse the repository at this point in the history
  • Loading branch information
brilee authored Jan 31, 2024
1 parent 59085ef commit 6220068
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 21 deletions.
6 changes: 6 additions & 0 deletions lilac/router_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@
def get_task_manifest() -> TaskManifest:
"""Get the tasks, both completed and pending."""
return get_task_manager().manifest()


@router.post('/{task_id}/cancel')
def cancel_task(task_id: str) -> None:
"""Cancel a task."""
get_task_manager().cancel_task(task_id)
64 changes: 43 additions & 21 deletions lilac/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import functools
import uuid
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
Expand Down Expand Up @@ -74,25 +75,29 @@ class TaskManifest(BaseModel):
class TaskManager:
"""Manage FastAPI background tasks."""

_tasks: dict[TaskId, TaskInfo]
_task_info: dict[TaskId, TaskInfo]
_task_threads: dict[TaskId, Thread]
_task_stopped: dict[TaskId, bool] # If true, task cancellation was requested.

def __init__(self) -> None:
# Maps a task id to the current progress of that task. Shared across all processes.
self._tasks = {}
self._task_info = {}
self._task_threads = {}
self._task_stopped = defaultdict(lambda: False)

def get_task_info(self, task_id: TaskId) -> TaskInfo:
"""Get the task info for a task."""
return self._tasks[task_id]
return self._task_info[task_id]

def manifest(self) -> TaskManifest:
"""Get all tasks."""
tasks_with_progress = [
(task.total_progress / task.total_len)
for task in self._tasks.values()
for task in self._task_info.values()
if task.total_progress and task.total_len and task.status != TaskStatus.COMPLETED
]
return TaskManifest(
tasks=self._tasks,
tasks=self._task_info,
progress=sum(tasks_with_progress) / len(tasks_with_progress) if tasks_with_progress else None,
)

Expand All @@ -113,12 +118,38 @@ def task_id(
start_timestamp=datetime.now().isoformat(),
total_len=total_len,
)
self._tasks[task_id] = new_task
self._task_info[task_id] = new_task
return task_id

def launch_task(self, task_id: TaskId, run_fn: Callable[..., Any]) -> None:
"""Start a task in a background thread."""

def _wrapper() -> None:
try:
run_fn()
except Exception as e:
log(e)
self.set_error(task_id, str(e))
else:
self.set_completed(task_id)

thread = Thread(target=_wrapper, daemon=True)
thread.start()
self._task_threads[task_id] = thread

def cancel_task(self, task_id: TaskId) -> None:
"""Mark a thread for cancellation.
The thread is not guaranteed to stop unless you also use get_progress_bar. If you implement
your own task execution logic, you can check tm._task_stopped[task_id] to see if the task
has been cancelled.
"""
self._task_stopped[task_id] = True
self._task_info[task_id].message = 'Task cancellation requested.'

def report_progress(self, task_id: TaskId, progress: int) -> None:
"""Report the progress of a task."""
task = self._tasks[task_id]
task = self._task_info[task_id]
task.total_progress = progress
elapsed_sec = (datetime.now() - datetime.fromisoformat(task.start_timestamp)).total_seconds()
ex_per_sec = progress / elapsed_sec if elapsed_sec else 0
Expand All @@ -128,15 +159,15 @@ def report_progress(self, task_id: TaskId, progress: int) -> None:

def set_error(self, task_id: TaskId, error: str) -> None:
"""Mark a task as errored."""
task = self._tasks[task_id]
task = self._task_info[task_id]
task.status = TaskStatus.ERROR
task.error = error
task.end_timestamp = datetime.now().isoformat()

def set_completed(self, task_id: TaskId) -> None:
"""Mark a task completed."""
end_timestamp = datetime.now().isoformat()
task = self._tasks[task_id]
task = self._task_info[task_id]
task.end_timestamp = end_timestamp

elapsed = datetime.fromisoformat(end_timestamp) - datetime.fromisoformat(task.start_timestamp)
Expand Down Expand Up @@ -180,6 +211,8 @@ def progress_reporter(it: Iterator[TProgress]) -> Iterator[TProgress]:
progress = offset
try:
for item in tqdm(it, initial=progress, total=task_info.total_len, desc=task_info.description):
if task_manager._task_stopped[task_id]:
raise AssertionError('Task cancelled successfully!')
progress += 1
if progress % 100 == 0:
task_manager.report_progress(task_id, progress)
Expand All @@ -195,15 +228,4 @@ def progress_reporter(it: Iterator[TProgress]) -> Iterator[TProgress]:
def launch_task(task_id: TaskId, run_fn: Callable) -> None:
"""Launch a task in a thread, handling exit conditions, etc.."""
tm = get_task_manager()

def _wrapper() -> None:
try:
run_fn()
except Exception as e:
log(e)
tm.set_error(task_id, str(e))
else:
tm.set_completed(task_id)

thread = Thread(target=_wrapper, daemon=True)
thread.start()
tm.launch_task(task_id, run_fn)
1 change: 1 addition & 0 deletions web/blueprint/src/lib/components/TaskStatus.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
<div class="text-s flex flex-row">
<div class="mr-2">{task.name}</div>
</div>
<button>Cancel button</button>
<div class="progress-container mt-3">
<ProgressBar
labelText={message || ''}
Expand Down
22 changes: 22 additions & 0 deletions web/lib/fastapi_client/services/TasksService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,26 @@ export class TasksService {
});
}

/**
* Cancel Task
* Cancel a task.
* @param taskId
* @returns any Successful Response
* @throws ApiError
*/
public static cancelTask(
taskId: string,
): CancelablePromise<any> {
return __request(OpenAPI, {
method: 'POST',
url: '/api/v1/tasks/{task_id}/cancel',
path: {
'task_id': taskId,
},
errors: {
422: `Validation Error`,
},
});
}

}

0 comments on commit 6220068

Please sign in to comment.