From ee8bd6efdde25963d0d824ab924decfaec62e120 Mon Sep 17 00:00:00 2001
From: Ritwik Gupta
Date: Wed, 31 Jan 2024 08:18:48 -0500
Subject: [PATCH] Add support for parallel syncing, address #21

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Fix pre-commit issues

Mypy issues
---
Reviewer notes (this section is ignored when the patch is applied):
- worker() catches subprocess.TimeoutExpired from the sync itself instead of
  treating an idle queue (queue.Empty) as a sync timeout.
- dir_watcher() sleeps for the remainder of `wait`, removes command files
  that point to non-existing directories, and still stops after one pass
  under pytest.
- Tests drive dir_watcher()/worker() in-process, because log records emitted
  in child processes are not visible to caplog.
- The post-image blob ids on the `index` lines were not regenerated.

 src/wandb_osh/cli.py    | 13 +++++++--
 src/wandb_osh/syncer.py | 71 ++++++++++++++++++++++++++++++++--------------
 tests/test_syncer.py    |  8 +++---
 3 files changed, 67 insertions(+), 25 deletions(-)

diff --git a/src/wandb_osh/cli.py b/src/wandb_osh/cli.py
index 344ceec..07f5afb 100644
--- a/src/wandb_osh/cli.py
+++ b/src/wandb_osh/cli.py
@@ -30,6 +30,12 @@ def _get_parser() -> ArgumentParser:
         type=int,
         help="Timeout for wandb sync. If <=0, no timeout.",
     )
+    parser.add_argument(
+        "--num-workers",
+        default=1,
+        type=int,
+        help="Number of parallel syncs to run at a time.",
+    )
     parser.add_argument(
         "wandb_options",
         nargs="*",
@@ -44,9 +50,12 @@ def main(argv=None) -> None:
     parser = _get_parser()
     args = parser.parse_args(argv)
     wandb_osh = WandbSyncer(
-        command_dir=args.command_dir, wait=args.wait, wandb_options=args.wandb_options
+        command_dir=args.command_dir,
+        wait=args.wait,
+        wandb_options=args.wandb_options,
+        num_workers=args.num_workers,
     )
-    wandb_osh.loop()
+    wandb_osh.start()
 
 
 if __name__ == "__main__":
diff --git a/src/wandb_osh/syncer.py b/src/wandb_osh/syncer.py
index 1e05976..1913a8e 100644
--- a/src/wandb_osh/syncer.py
+++ b/src/wandb_osh/syncer.py
@@ -3,8 +3,9 @@
 import os
 import subprocess
 import time
+from multiprocessing import Process, Queue
 from os import PathLike
 from pathlib import Path
 
 from wandb_osh import __version__
 from wandb_osh.config import _command_dir_default
@@ -19,6 +20,7 @@ def __init__(
         wandb_options: list[str] | None = None,
         *,
         timeout: int | float = 120,
+        num_workers: int = 1,
     ):
         """Class for interpreting command files and triggering `wandb sync`.
 
@@ -35,6 +37,25 @@ def __init__(
         self.wait = wait
         self.wandb_options = wandb_options
         self._timeout = timeout
+        self.num_workers = num_workers
+        self.target_queue: Queue = Queue()
+        self.workers: list[Process] = []
+
+    def start(self) -> None:
+        """Start the directory watcher process and the sync workers.
+
+        Spawns one process running `dir_watcher` and `num_workers` processes
+        running `worker`; returns without waiting for any of them.
+        """
+        self.command_dir.mkdir(parents=True, exist_ok=True)
+
+        watcher = Process(target=self.dir_watcher)
+        watcher.start()
+
+        for _ in range(self.num_workers):
+            p = Process(target=self.worker)
+            self.workers.append(p)
+            p.start()
 
     def sync(self, dir: PathLike) -> None:
         """Sync a directory. Thin wrapper around the `sync_dir` function.
@@ -44,7 +65,7 @@ def sync(self, dir: PathLike) -> None:
         """
         sync_dir(dir, options=self.wandb_options, timeout=self._timeout)
 
-    def loop(self) -> None:
+    def dir_watcher(self) -> None:
         """Read command files and trigger syncing"""
         logger.info(
             "wandb-osh v%s, starting to watch %s", __version__, self.command_dir
@@ -52,11 +73,8 @@ def loop(self) -> None:
         while True:
             start_time = time.time()
             self.command_dir.mkdir(parents=True, exist_ok=True)
-            command_files = []
-            targets = []
             for command_file in self.command_dir.glob("*.command"):
                 target = Path(command_file.read_text())
-                command_files.append(command_file)
                 if not target.is_dir():
                     logger.error(
                         "Command file %s points to non-existing directory %s",
@@ -64,24 +82,37 @@ def loop(self) -> None:
                         target,
                     )
+                    # Drop the invalid request so it is not reported forever.
+                    command_file.unlink(missing_ok=True)
                     continue
-                targets.append(target)
-            for target in set(targets):
-                logger.info("Syncing %s...", target)
-                try:
-                    self.sync(target)
-                except subprocess.TimeoutExpired:
-                    # try again later
-                    logger.warning("Syncing %s timed out. Trying later.", target)
-                    from wandb_osh.hooks import TriggerWandbSyncHook
-
-                    TriggerWandbSyncHook(self.command_dir)(target)
-                time.sleep(0.25)
-            for cf in command_files:
-                if cf.is_file():
-                    cf.unlink()
+                self.target_queue.put((command_file, target))
             if "PYTEST_CURRENT_TEST" in os.environ:
                 break
-            time.sleep(max(0.0, (time.time() - start_time) - self.wait))
+            # Sleep for whatever is left of the `wait` interval.
+            time.sleep(max(0.0, self.wait - (time.time() - start_time)))
+
+    def worker(self) -> None:
+        """Take queued (command file, target) pairs and sync each target."""
+        while True:
+            # Block until the watcher queues a request.
+            cf, target = self.target_queue.get()
+            if not cf.is_file():
+                # Stale duplicate: this request was already handled.
+                continue
+            logger.info("Syncing %s...", target)
+            try:
+                self.sync(target)
+            except subprocess.TimeoutExpired:
+                # try again later
+                logger.warning("Syncing %s timed out. Trying later.", target)
+                from wandb_osh.hooks import TriggerWandbSyncHook
+
+                TriggerWandbSyncHook(self.command_dir)(target)
+            else:
+                time.sleep(0.25)
+                if cf.is_file():
+                    cf.unlink()
+            if "PYTEST_CURRENT_TEST" in os.environ:
+                break
 
 
 def sync_dir(
diff --git a/tests/test_syncer.py b/tests/test_syncer.py
index c15af48..030f85c 100644
--- a/tests/test_syncer.py
+++ b/tests/test_syncer.py
@@ -17,13 +17,14 @@ def test_wandb_syncer(tmp_path, caplog):
     target = tmp_path / "test" / "123"
     (tmp_path / "123.command").write_text(str(target.resolve()))
     with caplog.at_level(logging.WARNING):
-        ws.loop()
+        ws.dir_watcher()
     assert "points to non-existing directory" in caplog.text
     caplog.clear()
     (tmp_path / "123.command").write_text(str(target.resolve()))
     target.mkdir(parents=True)
     with caplog.at_level(logging.DEBUG):
-        ws.loop()
+        ws.dir_watcher()
+        ws.worker()
     assert f"Command would be: wandb sync . in {target.resolve()}" in caplog.text
     set_log_level()
 
@@ -38,5 +39,6 @@ def test_wandb_sync_timeout(tmp_path, caplog):
     (tmp_path / "123.command").write_text(str(target.resolve()))
     target.mkdir(parents=True)
     with caplog.at_level(logging.DEBUG):
-        ws.loop()
+        ws.dir_watcher()
+        ws.worker()
     assert "timed out. Trying later." in caplog.text