-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add progress tracking and new worker features 🚀
- Introduced `Progress` class in `deciphonctl/cli.py` for real-time upload tracking. - Updated `poster.upload` calls to utilize progress feedback with `Progress`. - Bumped `deciphon-poster` version in `pyproject.toml` to `^0.2.1` to ensure compatibility. - Added new module `progress.py` to handle upload progress tracking using `rich.progress`. - Set up foundational worker architecture with new modules: `alarm.py`, `background.py`, `cli.py`, `download.py`, and `scanner.py`. This upgrade improves user experience by providing detailed progress feedback during uploads, ensuring seamless file transfers and better usability. The new worker modules lay the groundwork for enhanced background task handling, efficient file downloads, and real-time scanner operations through MQTT.
- Loading branch information
Showing
9 changed files
with
249 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from types import TracebackType | ||
from typing import Optional, Type | ||
|
||
import rich.progress | ||
|
||
|
||
class Progress:
    """Callable progress reporter backed by a ``rich`` progress bar.

    The instance is meant to be used as a context manager; entering and
    leaving it enters and leaves the wrapped ``rich.progress.Progress``.
    Calling the instance with ``(total_size, bytes_read)`` updates the bar.

    Args:
        description: Label of the progress task. It defaults to
            ``"Downloading"``, the label that used to be hard-coded, so
            existing callers are unaffected; pass another label (for example
            ``"Uploading"``) when tracking a different kind of transfer.
    """

    def __init__(self, description: str = "Downloading"):
        self._description = description
        self._progress = rich.progress.Progress()
        # The task is created lazily on the first call, because the total
        # size is only known once the first progress report arrives.
        self._task: Optional[rich.progress.TaskID] = None

    def __call__(self, total_size: int, bytes_read: int):
        """Report that ``bytes_read`` bytes out of ``total_size`` are done."""
        if self._task is None:
            self._task = self._progress.add_task(self._description, total=total_size)
        self._progress.update(self._task, completed=bytes_read)

    def __enter__(self):
        self._progress.__enter__()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ):
        self._progress.__exit__(exc_type, exc_val, exc_tb)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from threading import Event, Thread | ||
from typing import Callable | ||
|
||
|
||
class Alarm:
    """Call ``callback`` every ``seconds`` seconds on a background thread.

    The alarm can be started and stopped repeatedly, either explicitly with
    ``start()`` / ``stop()`` or by using the instance as a context manager.

    Args:
        seconds: Delay between two consecutive invocations of ``callback``.
        callback: Zero-argument callable executed on the alarm thread.
    """

    def __init__(self, seconds: float, callback: Callable[[], None]):
        self._seconds = seconds
        self._callback = callback
        self._stop = Event()
        # A Thread object can only be started once, so a fresh one is created
        # by every start() call; None means "not running".
        self._thread: Thread | None = None

    def _loop(self):
        # Event.wait() returns False when the timeout elapses (time to fire)
        # and True as soon as stop() sets the event (time to leave).
        while not self._stop.wait(timeout=self._seconds):
            self._callback()

    def start(self):
        """Start firing the callback.

        Raises:
            RuntimeError: If the alarm is already running.
        """
        if self._thread is not None:
            raise RuntimeError("alarm is already running")
        self._stop.clear()
        self._thread = Thread(target=self._loop)
        self._thread.start()

    def stop(self):
        """Stop the alarm and wait for its thread; no-op when not running."""
        thread = self._thread
        if thread is None:
            return
        self._stop.set()
        thread.join()
        self._thread = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *_):
        self.stop()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from queue import Queue | ||
from threading import Thread | ||
from typing import Callable | ||
|
||
callback_type = Callable[[], None]


class Background:
    """Run queued zero-argument callbacks, in order, on one worker thread.

    Use ``start()`` / ``stop()`` or the context-manager protocol to control
    the worker, and ``fire()`` to enqueue work. A ``None`` item on the queue
    is the sentinel that tells the worker to leave its loop.
    """

    def __init__(self):
        self._queue: Queue[callback_type | None] = Queue()
        # A Thread object can only be started once, so a fresh one is created
        # by every start() call; None means "not running".
        self._thread: Thread | None = None

    def fire(self, callback: callback_type):
        """Enqueue ``callback`` for execution on the worker thread."""
        self._queue.put(callback)

    def shutdown(self):
        """Enqueue the sentinel; callbacks queued before it still run."""
        self._queue.put(None)

    def loop(self):
        """Worker body: execute callbacks until the sentinel is received."""
        import logging

        while True:
            callback = self._queue.get()
            try:
                if callback is None:
                    break
                callback()
            except Exception:
                # One failing callback must not take the worker down, or
                # every callback queued after it would be silently dropped.
                logging.getLogger(__name__).exception("background callback failed")
            finally:
                # Also runs for the sentinel, so Queue.join() cannot hang.
                self._queue.task_done()

    def start(self):
        """Start the worker thread.

        Raises:
            RuntimeError: If the worker is already running.
        """
        assert self._queue.empty()
        if self._thread is not None:
            raise RuntimeError("background worker is already running")
        self._thread = Thread(target=self.loop)
        self._thread.start()

    def stop(self):
        """Drain the queue and join the worker; no-op when not running."""
        thread = self._thread
        if thread is None:
            return
        self.shutdown()
        thread.join()
        self._thread = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *_):
        self.stop()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import Optional | ||
|
||
import typer | ||
from deciphon_poster.poster import Poster | ||
|
||
from deciphon_worker.presser import presser_loop | ||
from deciphon_worker.scanner import scanner_loop | ||
from deciphon_worker.url import http_url | ||
|
||
# Command-line application object; the commands below register themselves
# on it through ``@app.command()``.
app = typer.Typer(
    pretty_exceptions_show_locals=False,
    pretty_exceptions_short=True,
    add_completion=False,
)
|
||
|
||
@app.command()
def scanner(
    sched_url: str,
    mqtt_host: str,
    mqtt_port: int,
    s3_url: Optional[str] = None,
):
    # Both URLs go through http_url(); the S3 one is optional and is handed
    # to Poster as None when it was not given on the command line.
    scheduler = http_url(sched_url)
    storage = None if s3_url is None else http_url(s3_url)
    scanner_loop(Poster(scheduler, storage), mqtt_host, mqtt_port)
|
||
|
||
@app.command()
def presser(
    sched_url: str,
    mqtt_host: str,
    mqtt_port: int,
    s3_url: Optional[str] = None,
):
    # Both URLs go through http_url(); the S3 one is optional and is handed
    # to Poster as None when it was not given on the command line.
    scheduler = http_url(sched_url)
    storage = None if s3_url is None else http_url(s3_url)
    presser_loop(Poster(scheduler, storage), mqtt_host, mqtt_port)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import shlex | ||
import subprocess | ||
from pathlib import Path | ||
|
||
import requests | ||
from loguru import logger | ||
from pydantic import HttpUrl | ||
|
||
|
||
def download(url: HttpUrl, dst: Path):
    """Download ``url`` into the file ``dst``.

    ``curl`` is tried first. When it cannot be run or exits with a non-zero
    status, the download is retried in-process with ``requests``, streaming
    the body to ``dst`` in chunks.

    Args:
        url: Location of the resource to fetch.
        dst: Path of the file to create or overwrite.

    Raises:
        requests.HTTPError: If the fallback request receives an error status
            (raised by ``raise_for_status``).
    """
    logger.info(f"downloading {url}")
    u = url.unicode_string()
    # --fail makes curl exit non-zero on HTTP error statuses instead of
    # saving the server's error page as if it were the requested file.
    # An argument list plus stdout redirection avoids the shell entirely.
    cmd = ["curl", "--silent", "--fail", "--location", u]
    try:
        with open(dst, "wb") as f:
            subprocess.check_call(cmd, stdout=f)
        return
    except (subprocess.CalledProcessError, OSError) as error:
        # OSError covers a missing or non-executable curl binary.
        logger.warning(f"curl could not download {u} ({error}); retrying with requests")

    chunk_size = 4_194_304  # 4 MiB
    with requests.get(u, stream=True) as r:
        r.raise_for_status()
        # Mode "wb" truncates anything a failed curl run left behind.
        with open(dst, "wb") as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from __future__ import annotations | ||
|
||
from functools import partial | ||
from pathlib import Path | ||
from queue import Queue | ||
|
||
from deciphon_core.schema import HMMFile, HMMName | ||
from deciphon_poster.poster import Poster | ||
from deciphon_poster.schema import JobUpdate | ||
from loguru import logger | ||
from paho.mqtt.client import CallbackAPIVersion, Client | ||
|
||
from deciphon_worker.background import Background | ||
from deciphon_worker.download import download | ||
from deciphon_worker.files import atomic_file_creation | ||
from deciphon_worker.models import ScanRequest | ||
from deciphon_worker.scan_thread import ScanThread | ||
|
||
# Permission bits rw-r----- (owner read/write, group read, others nothing).
FILE_MODE = 0o640
# MQTT topic the scanner subscribes to once it is connected.
TOPIC = "/deciphon.org/scan"
|
||
|
||
def on_connect(client, userdata, flags, reason_code, properties):
    """MQTT connect callback: log the result and subscribe to ``TOPIC``."""
    connected = f"connected to MQTT with result code {reason_code}"
    logger.info(connected)
    subscribing = f"subscribing to {TOPIC}"
    logger.info(subscribing)
    client.subscribe(TOPIC)
|
||
|
||
def on_message(client, userdata, msg):
    """MQTT message callback: parse a scan request and enqueue it.

    ``userdata`` is the queue of pending ``ScanRequest`` objects that was
    handed to the MQTT client. A payload that cannot be decoded or does not
    validate as a ``ScanRequest`` is logged and dropped, so that a single
    malformed message does not raise out of the callback into the MQTT
    client's network loop.
    """
    assert isinstance(msg.payload, bytes)
    requests: Queue[ScanRequest] = userdata
    try:
        payload = msg.payload.decode()
        logger.info(f"received <{payload}>")
        request = ScanRequest.model_validate_json(payload)
    except ValueError as error:
        # UnicodeDecodeError and pydantic's ValidationError are both
        # ValueError subclasses.
        logger.warning(f"ignoring malformed scan request: {error}")
        return
    requests.put(request)
|
||
|
||
def scanner_hash(hmm: HMMName, multi_hits: bool, hmmer3_compat: bool):
    """Return an integer key identifying a scanner configuration.

    The key is the built-in ``hash`` of the three settings rendered as text
    and joined with underscores.
    """
    fields = (str(hmm), str(multi_hits), str(hmmer3_compat))
    signature = "_".join(fields)
    return hash(signature)
|
||
|
||
def process_request(
    scans: dict[int, ScanThread],
    bg: Background,
    poster: Poster,
    request: ScanRequest,
):
    """Make sure the request's files exist locally, then dispatch it.

    Missing HMM and database files are downloaded from the URLs supplied by
    ``poster``. The request is then handed to the ``ScanThread`` registered
    in ``scans`` for its (hmm, multi_hits, hmmer3_compat) combination; that
    thread is created and started first if it does not exist yet.
    """
    logger.info(f"processing scan request: {request}")

    hmm_path = Path(request.hmm.name)
    db_path = Path(request.db.name)

    if not hmm_path.exists():
        with atomic_file_creation(hmm_path) as tmp:
            hmm_url = poster.download_hmm_url(hmm_path.name)
            download(hmm_url, tmp)

    if not db_path.exists():
        with atomic_file_creation(db_path) as tmp:
            db_url = poster.download_db_url(db_path.name)
            download(db_url, tmp)

    key = scanner_hash(request.hmm, request.multi_hits, request.hmmer3_compat)
    thread = scans.get(key)
    if thread is None:
        thread = ScanThread(
            bg,
            poster,
            HMMFile(path=hmm_path),
            request.multi_hits,
            request.hmmer3_compat,
        )
        # Register before starting, so the entry exists even if start() fails.
        scans[key] = thread
        thread.start()

    thread.fire(request)
|
||
|
||
def scanner_loop(poster: Poster, mqtt_host: str, mqtt_port: int):
    """Serve scan requests received over MQTT until interrupted.

    Requests are pushed onto a queue by the MQTT callbacks and handled one
    at a time here. When handling a request raises, the failure is logged
    and a failed ``JobUpdate`` for that job is posted through the background
    worker. The loop itself never returns normally; it ends only when an
    exception that is not an ``Exception`` subclass (for example
    ``KeyboardInterrupt``) propagates.

    Args:
        poster: Client used for download URLs and job updates.
        mqtt_host: Host name of the MQTT server.
        mqtt_port: Port of the MQTT server.
    """
    requests: Queue[ScanRequest] = Queue()
    scans: dict[int, ScanThread] = dict()

    logger.info(f"connecting to MQTT server (host={mqtt_host}, port={mqtt_port})")
    client = Client(CallbackAPIVersion.VERSION2, userdata=requests)
    client.on_connect = on_connect
    client.on_message = on_message
    client.connect(mqtt_host, mqtt_port)

    client.loop_start()
    try:
        with Background() as bg:
            while True:
                request = requests.get()
                try:
                    process_request(scans, bg, poster, request)
                except Exception as exception:
                    logger.warning(f"scanning failed: {exception}")
                    job_update = JobUpdate.fail(request.job_id, str(exception))
                    bg.fire(partial(poster.job_patch, job_update))
                finally:
                    requests.task_done()
    finally:
        # The loop above only ends by raising, so the MQTT network thread
        # has to be stopped here; code placed after the loop would never run.
        client.loop_stop()