From 687678b4f7311620a84f1d31c54dace48a093381 Mon Sep 17 00:00:00 2001
From: Lorenzo Gaifas
Date: Mon, 28 Aug 2023 15:50:04 +0200
Subject: [PATCH] fix topaz stuff!

---
 pyproject.toml         |  4 +--
 src/waretomo/_topaz.py | 74 ++++++++++++++++++++++++------------------
 src/waretomo/main.py   |  9 +++--
 3 files changed, 51 insertions(+), 36 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 5991895..ada3de2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,11 +28,11 @@ dynamic = ["version"]
 dependencies = [
     "GPUtil",
     "click",
-    "mdocfile",
+    "mdocfile==0.1.0",
     "pandas",
     "rich",
-    "topaz-em",
     "sh",
+    "topaz-em==0.2.5",
 ]
 
 # extras
diff --git a/src/waretomo/_topaz.py b/src/waretomo/_topaz.py
index 4445c78..a389bc3 100644
--- a/src/waretomo/_topaz.py
+++ b/src/waretomo/_topaz.py
@@ -12,18 +12,22 @@
 
 def _run_and_update_progress(progress, task, func, *args, **kwargs):
     std = io.StringIO()
+
+    def _run_with_redirection():
+        with contextlib.redirect_stderr(std):
+            func(*args, **kwargs)
+
     with futures.ThreadPoolExecutor(1) as executor:
-        with contextlib.redirect_stderr(std), contextlib.redirect_stdout(std):
-            job = executor.submit(func, *args, **kwargs)
+        job = executor.submit(_run_with_redirection)
 
-            last_read_pos = 0
-            while not job.done():
-                time.sleep(0.5)
-                std.seek(last_read_pos)
-                last = std.read()
-                if match := re.search(r"(\d+\.\d+)%", last):
-                    last_read_pos = std.tell()
-                    progress.update(task, completed=float(match.group(1)))
+        last_read_pos = 0
+        while not job.done():
+            time.sleep(0.5)
+            std.seek(last_read_pos)
+            last = std.read()
+            if match := re.search(r"(\d+\.\d+)%", last):
+                last_read_pos = std.tell()
+                progress.update(task, completed=float(match.group(1)), refresh=True)
 
     try:
         return job.result()
@@ -46,34 +50,16 @@ def topaz_batch(
     train=False,
     tile_size=32,
     patch_size=32,
+    gpus=None,
     dry_run=False,
     verbose=False,
     overwrite=False,
 ):
-    set_num_threads(0)
-    if train:
-        task = progress.add_task(description="Training...")
-        model, _ = _run_and_update_progress(
-            progress,
-            task,
-            train_model,
-            even_path=even,
-            odd_path=odd,
-            save_prefix=str(outdir / "trained_models" / model_name),
-            save_interval=10,
-            device=-2,
-            tilesize=patch_size,
-            base_kernel_width=11,
-            num_workers=multiprocessing.cpu_count(),
-        )
-    else:
-        model = load_model(model_name, base_kernel_width=11)
-        model.eval()
-        model, use_cuda, num_devices = set_device(model, -2)
-
     inputs = [ts["recon"] for ts in tilt_series]
 
     if verbose:
+        if train:
+            print(f"training model: '{model_name}' with inputs '{even}' and '{odd}'")
         if len(inputs) > 2:
             print(f"denoising: [{inputs[0]} [...] {inputs[-1]}]")
         else:
@@ -81,6 +67,32 @@
         print(f"output: {outdir}")
 
     if not dry_run:
+        set_num_threads(0)
+        std = io.StringIO()
+        if train:
+            task = progress.add_task(description="Training...")
+            model, _ = _run_and_update_progress(
+                progress,
+                task,
+                train_model,
+                even_path=even,
+                odd_path=odd,
+                save_prefix=str(outdir / "trained_models" / model_name),
+                save_interval=10,
+                device=-2
+                if gpus is None
+                else next(iter(gpus)),  # TODO: actually run with multiple gpus?
+                tilesize=patch_size,
+                base_kernel_width=11,
+                num_workers=multiprocessing.cpu_count(),
+            )
+        else:
+            with contextlib.redirect_stdout(std):
+                model = load_model(model_name, base_kernel_width=11)
+            with contextlib.redirect_stdout(std):
+                model.eval()
+            model, use_cuda, num_devices = set_device(model, -2)
+
         for path in progress.track(inputs, description="Denoising..."):
             subtask = progress.add_task(description=path.name)
             _run_and_update_progress(
diff --git a/src/waretomo/main.py b/src/waretomo/main.py
index 1a41a72..1929bff 100644
--- a/src/waretomo/main.py
+++ b/src/waretomo/main.py
@@ -255,9 +255,10 @@ def cli(
         "verbose": verbose,
     }
 
-    topaz_opts = {
+    topaz_kwargs = {
         "train": train,
         "model_name": topaz_model,
+        "gpus": gpus,
         "tile_size": topaz_tile_size,
         "patch_size": topaz_patch_size,
     }
@@ -289,7 +290,9 @@ def cli(
     aretomo_opts_log = "".join(
         f'{nl}{" " * 12}- {k}: {v}' for k, v in aretomo_kwargs.items()
     )
-    topaz_log = "".join(f'{nl}{" " * 12}- {k}: {v}' for k, v in topaz_opts.items())
+    topaz_log = "".join(
+        f'{nl}{" " * 12}- {k}: {v}' for k, v in topaz_kwargs.items()
+    )
 
     print(
         Panel(
@@ -402,6 +405,6 @@ def cli(
             outdir=outdir_denoised,
             even=str(output_dir / "even"),
             odd=str(output_dir / "odd"),
-            **topaz_opts,
+            **topaz_kwargs,
             **meta_kwargs,
         )
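
Note, not part of the patch: a minimal sketch of how the reworked
_run_and_update_progress helper can be exercised, assuming rich and waretomo
are importable; fake_topaz_step is a hypothetical stand-in for a topaz call
that, like the real training/denoising loops the helper wraps, reports
"NN.N%" progress on stderr.

    # Illustrative only -- not part of the patch.
    import sys
    import time

    from rich.progress import Progress

    from waretomo._topaz import _run_and_update_progress


    def fake_topaz_step():
        # hypothetical stand-in: emits percentages on stderr in the
        # "NN.N%" format that the helper's regex expects
        for pct in range(0, 101, 20):
            print(f"{pct:.1f}%", file=sys.stderr)
            time.sleep(1)


    with Progress() as progress:
        task = progress.add_task(description="Denoising...")
        # the helper runs fake_topaz_step in a worker thread with stderr
        # redirected into a StringIO buffer, polls the buffer every 0.5 s,
        # and pushes the latest percentage into the bar (refresh=True
        # forces a redraw even when rich would otherwise throttle it)
        _run_and_update_progress(progress, task, fake_topaz_step)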