From d836e2b655f026f7571444f22dc53c5ae71cf29b Mon Sep 17 00:00:00 2001 From: Danilo Horta Date: Thu, 24 Oct 2024 17:23:55 +0100 Subject: [PATCH] =?UTF-8?q?Add=20progress=20tracking=20and=20new=20worker?= =?UTF-8?q?=20features=20=F0=9F=9A=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Introduced `Progress` class in `deciphonctl/cli.py` for real-time upload tracking. - Updated `poster.upload` calls to utilize progress feedback with `Progress`. - Bumped `deciphon-poster` version in `pyproject.toml` to `^0.2.1` to ensure compatibility. - Added new module `progress.py` to handle upload progress tracking using `rich.progress`. - Set up foundational worker architecture with new modules: `alarm.py`, `background.py`, `cli.py`, `download.py`, and `scanner.py`. This upgrade improves user experience by providing detailed progress feedback during uploads, ensuring seamless file transfers and better usability. The new worker modules lay the groundwork for enhanced background task handling, efficient file downloads, and real-time scanner operations through MQTT. --- control/deciphonctl/cli.py | 9 ++- control/deciphonctl/progress.py | 27 ++++++++ control/pyproject.toml | 2 +- worker/deciphon_worker/__init__.py | 0 worker/deciphon_worker/alarm.py | 29 +++++++++ worker/deciphon_worker/background.py | 40 ++++++++++++ worker/deciphon_worker/cli.py | 30 +++++++++ worker/deciphon_worker/download.py | 22 +++++++ worker/deciphon_worker/scanner.py | 93 ++++++++++++++++++++++++++++ 9 files changed, 249 insertions(+), 3 deletions(-) create mode 100644 control/deciphonctl/progress.py create mode 100644 worker/deciphon_worker/__init__.py create mode 100644 worker/deciphon_worker/alarm.py create mode 100644 worker/deciphon_worker/background.py create mode 100644 worker/deciphon_worker/cli.py create mode 100644 worker/deciphon_worker/download.py create mode 100644 worker/deciphon_worker/scanner.py diff --git a/control/deciphonctl/cli.py b/control/deciphonctl/cli.py index 86e2259..065062f 100644 --- a/control/deciphonctl/cli.py +++ b/control/deciphonctl/cli.py @@ -16,6 +16,7 @@ from deciphonctl.catch_validation import catch_validation from deciphonctl.display_exception import display_exception from deciphonctl.log_level import LogLevel +from deciphonctl.progress import Progress from deciphonctl.settings import ( Settings, SettingsFields, @@ -134,7 +135,9 @@ def config_get(option: CFGOPT): def hmm_add(hmmfile: HMMFILE, gencode: GENCODE, epsilon: EPSILON = 0.01): settings = Settings() poster = Poster(settings.sched_url, settings.s3_url) - poster.upload(hmmfile, poster.upload_hmm_post(hmmfile.name)) + upload_post = poster.upload_hmm_post(hmmfile.name) + with Progress() as progress: + poster.upload(hmmfile, upload_post, progress) poster.hmm_post(HMMName(name=hmmfile.name), gencode, epsilon) @@ -162,7 +165,9 @@ def hmm_ls(): def db_add(dbfile: DBFILE, gencode: GENCODE, epsilon: EPSILON = 0.01): settings = Settings() poster = Poster(settings.sched_url, settings.s3_url) - poster.upload(dbfile, poster.upload_db_post(dbfile.name)) + upload_post = poster.upload_db_post(dbfile.name) + with Progress() as progress: + poster.upload(dbfile, upload_post, progress) poster.db_post(DBName(name=dbfile.name)) diff --git a/control/deciphonctl/progress.py b/control/deciphonctl/progress.py new file mode 100644 index 0000000..d3ba918 --- /dev/null +++ b/control/deciphonctl/progress.py @@ -0,0 +1,27 @@ +from types import TracebackType +from typing import Optional, Type + +import rich.progress + + +class Progress: + def __init__(self): + self._progress = rich.progress.Progress() + self._task: rich.progress.TaskID | None = None + + def __call__(self, total_size: int, bytes_read: int): + if self._task is None: + self._task = self._progress.add_task("Downloading", total=total_size) + self._progress.update(self._task, completed=bytes_read) + + def __enter__(self): + self._progress.__enter__() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ): + self._progress.__exit__(exc_type, exc_val, exc_tb) diff --git a/control/pyproject.toml b/control/pyproject.toml index fcf5136..5cc709a 100644 --- a/control/pyproject.toml +++ b/control/pyproject.toml @@ -18,7 +18,7 @@ loguru = "^0.7" requests-toolbelt = "^1.0" paho-mqtt = "^2.1" deciphon = "^1.0.0" -deciphon-poster = "^0.1.2" +deciphon-poster = "^0.2.1" xdg-base-dirs = "^6.0.2" python-dotenv = "^1.0.1" diff --git a/worker/deciphon_worker/__init__.py b/worker/deciphon_worker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/worker/deciphon_worker/alarm.py b/worker/deciphon_worker/alarm.py new file mode 100644 index 0000000..f1e3ab6 --- /dev/null +++ b/worker/deciphon_worker/alarm.py @@ -0,0 +1,29 @@ +from threading import Event, Thread +from typing import Callable + + +class Alarm: + def __init__(self, seconds: float, callback: Callable[[], None]): + self._seconds = seconds + self._callback = callback + self._stop = Event() + self._thread = Thread(target=self._loop) + + def _loop(self): + while not self._stop.wait(timeout=self._seconds): + self._callback() + + def start(self): + self._stop.clear() + self._thread.start() + + def stop(self): + self._stop.set() + self._thread.join() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *_): + self.stop() diff --git a/worker/deciphon_worker/background.py b/worker/deciphon_worker/background.py new file mode 100644 index 0000000..9b8d926 --- /dev/null +++ b/worker/deciphon_worker/background.py @@ -0,0 +1,40 @@ +from queue import Queue +from threading import Thread +from typing import Callable + +callback_type = Callable[[], None] + + +class Background: + def __init__(self): + self._queue: Queue[callback_type | None] = Queue() + self._thread = Thread(target=self.loop) + + def fire(self, callback: callback_type): + self._queue.put(callback) + + def shutdown(self): + self._queue.put(None) + + def loop(self): + while True: + callback = self._queue.get() + if callback is None: + break + callback() + self._queue.task_done() + + def start(self): + assert self._queue.empty() + self._thread.start() + + def stop(self): + self.shutdown() + self._thread.join() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *_): + self.stop() diff --git a/worker/deciphon_worker/cli.py b/worker/deciphon_worker/cli.py new file mode 100644 index 0000000..3682916 --- /dev/null +++ b/worker/deciphon_worker/cli.py @@ -0,0 +1,30 @@ +from typing import Optional + +import typer +from deciphon_poster.poster import Poster + +from deciphon_worker.presser import presser_loop +from deciphon_worker.scanner import scanner_loop +from deciphon_worker.url import http_url + +app = typer.Typer( + add_completion=False, + pretty_exceptions_short=True, + pretty_exceptions_show_locals=False, +) + + +@app.command() +def scanner( + sched_url: str, mqtt_host: str, mqtt_port: int, s3_url: Optional[str] = None +): + poster = Poster(http_url(sched_url), s3_url if s3_url is None else http_url(s3_url)) + scanner_loop(poster, mqtt_host, mqtt_port) + + +@app.command() +def presser( + sched_url: str, mqtt_host: str, mqtt_port: int, s3_url: Optional[str] = None +): + poster = Poster(http_url(sched_url), s3_url if s3_url is None else http_url(s3_url)) + presser_loop(poster, mqtt_host, mqtt_port) diff --git a/worker/deciphon_worker/download.py b/worker/deciphon_worker/download.py new file mode 100644 index 0000000..120bcc3 --- /dev/null +++ b/worker/deciphon_worker/download.py @@ -0,0 +1,22 @@ +import shlex +import subprocess +from pathlib import Path + +import requests +from loguru import logger +from pydantic import HttpUrl + + +def download(url: HttpUrl, dst: Path): + logger.info(f"downloading {url}") + u = url.unicode_string() + try: + cmd = shlex.join(["curl", "--silent", "-L", u]) + " > " + shlex.join([str(dst)]) + subprocess.check_call(cmd, shell=True) + except subprocess.CalledProcessError: + chunk_size = 4_194_304 + with requests.get(u, stream=True) as r: + r.raise_for_status() + with open(dst, "wb") as f: + for chunk in r.iter_content(chunk_size=chunk_size): + f.write(chunk) diff --git a/worker/deciphon_worker/scanner.py b/worker/deciphon_worker/scanner.py new file mode 100644 index 0000000..aa40022 --- /dev/null +++ b/worker/deciphon_worker/scanner.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from functools import partial +from pathlib import Path +from queue import Queue + +from deciphon_core.schema import HMMFile, HMMName +from deciphon_poster.poster import Poster +from deciphon_poster.schema import JobUpdate +from loguru import logger +from paho.mqtt.client import CallbackAPIVersion, Client + +from deciphon_worker.background import Background +from deciphon_worker.download import download +from deciphon_worker.files import atomic_file_creation +from deciphon_worker.models import ScanRequest +from deciphon_worker.scan_thread import ScanThread + +FILE_MODE = 0o640 +TOPIC = "/deciphon.org/scan" + + +def on_connect(client, userdata, flags, reason_code, properties): + logger.info(f"connected to MQTT with result code {reason_code}") + logger.info(f"subscribing to {TOPIC}") + client.subscribe(TOPIC) + + +def on_message(client, userdata, msg): + assert isinstance(msg.payload, bytes) + payload = msg.payload.decode() + logger.info(f"received <{payload}>") + requests: Queue[ScanRequest] = userdata + requests.put(ScanRequest.model_validate_json(payload)) + + +def scanner_hash(hmm: HMMName, multi_hits: bool, hmmer3_compat: bool): + return hash(f"{str(hmm)}_{multi_hits}_{hmmer3_compat}") + + +def process_request( + scans: dict[int, ScanThread], + bg: Background, + poster: Poster, + request: ScanRequest, +): + logger.info(f"processing scan request: {request}") + + hmmfile = Path(request.hmm.name) + dbfile = Path(request.db.name) + + if not hmmfile.exists(): + with atomic_file_creation(hmmfile) as t: + download(poster.download_hmm_url(hmmfile.name), t) + + if not dbfile.exists(): + with atomic_file_creation(dbfile) as t: + download(poster.download_db_url(dbfile.name), t) + + id = scanner_hash(request.hmm, request.multi_hits, request.hmmer3_compat) + if id not in scans: + hmm = HMMFile(path=hmmfile) + scans[id] = ScanThread( + bg, poster, hmm, request.multi_hits, request.hmmer3_compat + ) + scans[id].start() + + scans[id].fire(request) + + +def scanner_loop(poster: Poster, mqtt_host: str, mqtt_port: int): + requests: Queue[ScanRequest] = Queue() + scans: dict[int, ScanThread] = dict() + + logger.info(f"connecting to MQTT server (host={mqtt_host}, port={mqtt_port})") + client = Client(CallbackAPIVersion.VERSION2, userdata=requests) + client.on_connect = on_connect + client.on_message = on_message + client.connect(mqtt_host, mqtt_port) + + client.loop_start() + with Background() as bg: + while True: + request = requests.get() + try: + process_request(scans, bg, poster, request) + except Exception as exception: + logger.warning(f"scanning failed: {exception}") + job_update = JobUpdate.fail(request.job_id, str(exception)) + bg.fire(partial(poster.job_patch, job_update)) + finally: + requests.task_done() + client.loop_stop()