Skip to content

Commit

Permalink
Add progress tracking and new worker features 🚀
Browse files Browse the repository at this point in the history
- Introduced `Progress` class in `deciphonctl/cli.py` for real-time upload tracking.
- Updated `poster.upload` calls to utilize progress feedback with `Progress`.
- Bumped `deciphon-poster` version in `pyproject.toml` to `^0.2.1` to ensure compatibility.
- Added new module `progress.py` to handle upload progress tracking using `rich.progress`.
- Set up foundational worker architecture with new modules: `alarm.py`, `background.py`, `cli.py`, `download.py`, and `scanner.py`.

This upgrade improves user experience by providing detailed progress feedback during uploads, ensuring seamless file transfers and better usability. The new worker modules lay the groundwork for enhanced background task handling, efficient file downloads, and real-time scanner operations through MQTT.
  • Loading branch information
horta committed Oct 24, 2024
1 parent 177e1e4 commit d836e2b
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 3 deletions.
9 changes: 7 additions & 2 deletions control/deciphonctl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from deciphonctl.catch_validation import catch_validation
from deciphonctl.display_exception import display_exception
from deciphonctl.log_level import LogLevel
from deciphonctl.progress import Progress
from deciphonctl.settings import (
Settings,
SettingsFields,
Expand Down Expand Up @@ -134,7 +135,9 @@ def config_get(option: CFGOPT):
def hmm_add(hmmfile: HMMFILE, gencode: GENCODE, epsilon: EPSILON = 0.01):
    # Upload an HMM file with a progress bar, then post its record
    # (name, gencode, epsilon) through the Poster.
    # Documented with comments rather than a docstring on purpose: CLI
    # frameworks may surface a docstring as the command's help text.
    settings = Settings()
    poster = Poster(settings.sched_url, settings.s3_url)
    # Request the upload target once and send the file exactly once; the
    # progress object is handed to the upload as its third argument.
    upload_post = poster.upload_hmm_post(hmmfile.name)
    with Progress() as progress:
        poster.upload(hmmfile, upload_post, progress)
    poster.hmm_post(HMMName(name=hmmfile.name), gencode, epsilon)


Expand Down Expand Up @@ -162,7 +165,9 @@ def hmm_ls():
def db_add(dbfile: DBFILE, gencode: GENCODE, epsilon: EPSILON = 0.01):
    # Upload a database file with a progress bar, then post its record
    # through the Poster.
    # Documented with comments rather than a docstring on purpose: CLI
    # frameworks may surface a docstring as the command's help text.
    # NOTE(review): gencode and epsilon are accepted but not used by the
    # visible body; kept so the command's interface is unchanged.
    settings = Settings()
    poster = Poster(settings.sched_url, settings.s3_url)
    # Request the upload target once and send the file exactly once; the
    # progress object is handed to the upload as its third argument.
    upload_post = poster.upload_db_post(dbfile.name)
    with Progress() as progress:
        poster.upload(dbfile, upload_post, progress)
    poster.db_post(DBName(name=dbfile.name))


Expand Down
27 changes: 27 additions & 0 deletions control/deciphonctl/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from types import TracebackType
from typing import Optional, Type

import rich.progress


class Progress:
    """Transfer progress bar backed by ``rich.progress``.

    Use the instance as a context manager and pass it as a progress
    callback: every call reports ``(total_size, bytes_read)`` and moves the
    bar to ``bytes_read``. The bar's task is created lazily on the first
    call, because the total size is only known at that point.

    Args:
        description: Label displayed next to the bar. Defaults to
            ``"Uploading"`` since the bar is used to track uploads.
    """

    def __init__(self, description: str = "Uploading"):
        self._progress = rich.progress.Progress()
        self._description = description
        self._task: Optional[rich.progress.TaskID] = None

    def __call__(self, total_size: int, bytes_read: int):
        """Report that ``bytes_read`` out of ``total_size`` bytes are done."""
        if self._task is None:
            self._task = self._progress.add_task(
                self._description, total=total_size
            )
        self._progress.update(self._task, completed=bytes_read)

    def __enter__(self):
        """Enter the underlying rich progress display and return ``self``."""
        self._progress.__enter__()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ):
        """Leave the underlying rich progress display."""
        self._progress.__exit__(exc_type, exc_val, exc_tb)
2 changes: 1 addition & 1 deletion control/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ loguru = "^0.7"
requests-toolbelt = "^1.0"
paho-mqtt = "^2.1"
deciphon = "^1.0.0"
deciphon-poster = "^0.1.2"
deciphon-poster = "^0.2.1"
xdg-base-dirs = "^6.0.2"
python-dotenv = "^1.0.1"

Expand Down
Empty file.
29 changes: 29 additions & 0 deletions worker/deciphon_worker/alarm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from threading import Event, Thread
from typing import Callable


class Alarm:
    """Call ``callback`` every ``seconds`` on a background thread.

    The alarm can be started and stopped repeatedly, either explicitly or as
    a context manager. The first call happens ``seconds`` after ``start()``.

    Args:
        seconds: Interval between consecutive callback invocations.
        callback: Zero-argument callable run on the alarm thread.
    """

    def __init__(self, seconds: float, callback: Callable[[], None]):
        self._seconds = seconds
        self._callback = callback
        self._stop = Event()
        self._thread: Thread | None = None

    def _loop(self):
        # Event.wait() returns False when the timeout elapses (time to fire)
        # and True as soon as stop() sets the event (time to leave).
        while not self._stop.wait(timeout=self._seconds):
            self._callback()

    def start(self):
        """Start the alarm thread.

        Raises:
            RuntimeError: If the alarm is already running.
        """
        if self._thread is not None and self._thread.is_alive():
            raise RuntimeError("alarm is already running")
        self._stop.clear()
        # A Thread object can only be started once, so a fresh one is built
        # on every start; this is what makes the alarm restartable.
        self._thread = Thread(target=self._loop)
        self._thread.start()

    def stop(self):
        """Stop the alarm and wait for its thread; no-op if not started."""
        thread, self._thread = self._thread, None
        if thread is None:
            return
        self._stop.set()
        thread.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *_):
        self.stop()
40 changes: 40 additions & 0 deletions worker/deciphon_worker/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from queue import Queue
from threading import Thread
from typing import Callable

# A unit of background work: a callable taking no arguments.
callback_type = Callable[[], None]


class Background:
    """Run fired callbacks one at a time, in FIFO order, on a worker thread.

    Use as a context manager, or call ``start()`` / ``stop()`` explicitly.
    ``stop()`` lets the callbacks queued so far run before the worker exits.
    A callback that raises is logged and does not stop the worker.
    """

    def __init__(self):
        # None is the shutdown sentinel placed on the queue by shutdown().
        self._queue: Queue[callback_type | None] = Queue()
        self._thread = Thread(target=self.loop)

    def fire(self, callback: callback_type):
        """Queue ``callback`` to be run on the worker thread."""
        self._queue.put(callback)

    def shutdown(self):
        """Ask the worker to exit once it reaches this point in the queue."""
        self._queue.put(None)

    def loop(self):
        """Worker body: run queued callbacks until the sentinel arrives."""
        import logging

        log = logging.getLogger(__name__)
        while True:
            callback = self._queue.get()
            try:
                if callback is None:
                    break
                callback()
            except Exception:
                # One failing callback must not kill the worker; otherwise
                # every callback fired afterwards would be silently dropped.
                log.exception("background callback failed")
            finally:
                # Always balance get() with task_done(), sentinel included,
                # so the queue's unfinished-task count returns to zero.
                self._queue.task_done()

    def start(self):
        """Start the worker thread; nothing may have been fired yet."""
        assert self._queue.empty()
        self._thread.start()

    def stop(self):
        """Request shutdown and wait for the worker thread to finish."""
        self.shutdown()
        self._thread.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *_):
        self.stop()
30 changes: 30 additions & 0 deletions worker/deciphon_worker/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Optional

import typer
from deciphon_poster.poster import Poster

from deciphon_worker.presser import presser_loop
from deciphon_worker.scanner import scanner_loop
from deciphon_worker.url import http_url

# Typer application object; the `scanner` and `presser` functions in this
# module are registered on it as commands via `@app.command()`.
# Shell-completion options are turned off, and the two pretty_exceptions_*
# flags ask for short tracebacks without local variables.
app = typer.Typer(
add_completion=False,
pretty_exceptions_short=True,
pretty_exceptions_show_locals=False,
)


@app.command()
def scanner(
    sched_url: str, mqtt_host: str, mqtt_port: int, s3_url: Optional[str] = None
):
    # CLI command: build a Poster from the given URLs and run the scanner
    # loop against the MQTT broker at mqtt_host:mqtt_port.
    # (A comment, not a docstring: Typer would show a docstring as help text.)
    scheduler = http_url(sched_url)
    # The S3 URL is optional; None is handed to Poster unchanged.
    storage = None if s3_url is None else http_url(s3_url)
    scanner_loop(Poster(scheduler, storage), mqtt_host, mqtt_port)


@app.command()
def presser(
    sched_url: str, mqtt_host: str, mqtt_port: int, s3_url: Optional[str] = None
):
    # CLI command: build a Poster from the given URLs and run the presser
    # loop against the MQTT broker at mqtt_host:mqtt_port.
    # (A comment, not a docstring: Typer would show a docstring as help text.)
    scheduler = http_url(sched_url)
    # The S3 URL is optional; None is handed to Poster unchanged.
    storage = None if s3_url is None else http_url(s3_url)
    presser_loop(Poster(scheduler, storage), mqtt_host, mqtt_port)
22 changes: 22 additions & 0 deletions worker/deciphon_worker/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import shlex
import subprocess
from pathlib import Path

import requests
from loguru import logger
from pydantic import HttpUrl


def download(url: HttpUrl, dst: Path):
    """Download ``url`` into the file ``dst``.

    ``curl`` is tried first. If it is not installed or exits with an error
    (including an HTTP error status), the download is retried with
    ``requests``, streamed to disk in 4 MiB chunks.

    Args:
        url: Location to fetch.
        dst: File to write; created or truncated.

    Raises:
        requests.HTTPError: If the fallback request gets an error status.
    """
    logger.info(f"downloading {url}")
    u = url.unicode_string()
    try:
        # The command is an argument list, so no shell (and no quoting) is
        # involved; the redirection to dst is done by handing curl the file.
        # --fail makes curl exit non-zero on HTTP errors instead of saving
        # the server's error page as if it were the requested file.
        with open(dst, "wb") as f:
            subprocess.check_call(
                ["curl", "--silent", "--fail", "-L", u], stdout=f
            )
    except (subprocess.CalledProcessError, FileNotFoundError) as error:
        # FileNotFoundError covers a missing curl executable.
        logger.warning(f"curl download failed ({error}); retrying with requests")
        chunk_size = 4_194_304
        with requests.get(u, stream=True) as r:
            r.raise_for_status()
            with open(dst, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
93 changes: 93 additions & 0 deletions worker/deciphon_worker/scanner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

from functools import partial
from pathlib import Path
from queue import Queue

from deciphon_core.schema import HMMFile, HMMName
from deciphon_poster.poster import Poster
from deciphon_poster.schema import JobUpdate
from loguru import logger
from paho.mqtt.client import CallbackAPIVersion, Client

from deciphon_worker.background import Background
from deciphon_worker.download import download
from deciphon_worker.files import atomic_file_creation
from deciphon_worker.models import ScanRequest
from deciphon_worker.scan_thread import ScanThread

# Octal permission bits (rw-r-----). Not referenced in the visible part of
# this module; presumably meant for files the worker creates (TODO confirm).
FILE_MODE = 0o640
# MQTT topic that on_connect subscribes to; messages received on it are
# parsed as ScanRequest by on_message.
TOPIC = "/deciphon.org/scan"


def on_connect(client, userdata, flags, reason_code, properties):
    """MQTT connect callback: log the result code and subscribe to ``TOPIC``.

    Only ``client`` and ``reason_code`` are used; the other parameters are
    part of the callback signature.
    """
    connected = f"connected to MQTT with result code {reason_code}"
    logger.info(connected)
    subscribing = f"subscribing to {TOPIC}"
    logger.info(subscribing)
    client.subscribe(TOPIC)


def on_message(client, userdata, msg):
    """MQTT message callback: parse the payload and queue the scan request.

    ``userdata`` is the queue of pending ``ScanRequest`` objects. A payload
    that cannot be decoded or does not validate as a ``ScanRequest`` is
    logged and dropped.
    """
    assert isinstance(msg.payload, bytes)
    requests: Queue[ScanRequest] = userdata
    try:
        payload = msg.payload.decode()
        logger.info(f"received <{payload}>")
        request = ScanRequest.model_validate_json(payload)
    except ValueError as error:
        # UnicodeDecodeError and pydantic's ValidationError are both
        # ValueError subclasses. Letting either escape would propagate into
        # the MQTT client's network loop, so one malformed message could
        # stop the delivery of all later ones.
        logger.warning(f"discarding invalid scan request: {error}")
        return
    requests.put(request)


def scanner_hash(hmm: HMMName, multi_hits: bool, hmmer3_compat: bool):
    """Return an integer key for one (hmm, multi_hits, hmmer3_compat) setup.

    The key is the built-in ``hash`` of the three values joined with
    underscores, so it is only meaningful within the current process.
    """
    label = "{}_{}_{}".format(str(hmm), multi_hits, hmmer3_compat)
    return hash(label)

def process_request(
    scans: dict[int, ScanThread],
    bg: Background,
    poster: Poster,
    request: ScanRequest,
):
    """Make a scan request's input files available and dispatch the request.

    The HMM and database files named by the request are downloaded into the
    current directory when they do not exist yet. The request is then fired
    at the ``ScanThread`` registered in ``scans`` for its
    (hmm, multi_hits, hmmer3_compat) combination; that thread is created and
    started on first use.
    """
    logger.info(f"processing scan request: {request}")

    hmm_path = Path(request.hmm.name)
    db_path = Path(request.db.name)

    # Each download goes to the path yielded by atomic_file_creation.
    if not hmm_path.exists():
        with atomic_file_creation(hmm_path) as tmp:
            download(poster.download_hmm_url(hmm_path.name), tmp)

    if not db_path.exists():
        with atomic_file_creation(db_path) as tmp:
            download(poster.download_db_url(db_path.name), tmp)

    key = scanner_hash(request.hmm, request.multi_hits, request.hmmer3_compat)
    thread = scans.get(key)
    if thread is None:
        hmm = HMMFile(path=hmm_path)
        thread = ScanThread(
            bg, poster, hmm, request.multi_hits, request.hmmer3_compat
        )
        # Registered before start(), so the entry exists even if start fails.
        scans[key] = thread
        thread.start()

    thread.fire(request)


def scanner_loop(poster: Poster, mqtt_host: str, mqtt_port: int):
    """Serve scan requests received over MQTT until interrupted.

    Requests are queued by the MQTT callbacks running on the client's
    network thread and handled here one at a time. When handling a request
    fails, the job is reported as failed through ``poster.job_patch`` on the
    background worker and the loop moves on to the next request.

    Args:
        poster: Client used to resolve download URLs and report job status.
        mqtt_host: MQTT broker host name.
        mqtt_port: MQTT broker port.
    """
    requests: Queue[ScanRequest] = Queue()
    scans: dict[int, ScanThread] = dict()

    logger.info(f"connecting to MQTT server (host={mqtt_host}, port={mqtt_port})")
    client = Client(CallbackAPIVersion.VERSION2, userdata=requests)
    client.on_connect = on_connect
    client.on_message = on_message
    client.connect(mqtt_host, mqtt_port)

    client.loop_start()
    try:
        with Background() as bg:
            while True:
                request = requests.get()
                try:
                    process_request(scans, bg, poster, request)
                except Exception as exception:
                    logger.warning(f"scanning failed: {exception}")
                    job_update = JobUpdate.fail(request.job_id, str(exception))
                    bg.fire(partial(poster.job_patch, job_update))
                finally:
                    requests.task_done()
    finally:
        # The loop above only ends through an exception (for example
        # KeyboardInterrupt), so the network thread must be stopped here.
        client.loop_stop()

0 comments on commit d836e2b

Please sign in to comment.