Skip to content

Commit

Permalink
fix topaz stuff!
Browse files Browse the repository at this point in the history
  • Loading branch information
brisvag committed Aug 28, 2023
1 parent ec3ebda commit 687678b
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 36 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ dynamic = ["version"]
dependencies = [
"GPUtil",
"click",
"mdocfile",
"mdocfile==0.1.0",
"pandas",
"rich",
"topaz-em",
"sh",
"topaz-em==0.2.5",
]

# extras
Expand Down
74 changes: 43 additions & 31 deletions src/waretomo/_topaz.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,22 @@

def _run_and_update_progress(progress, task, func, *args, **kwargs):
std = io.StringIO()

def _run_with_redirection():
with contextlib.redirect_stderr(std):
func(*args, **kwargs)

with futures.ThreadPoolExecutor(1) as executor:
with contextlib.redirect_stderr(std), contextlib.redirect_stdout(std):
job = executor.submit(func, *args, **kwargs)
job = executor.submit(_run_with_redirection)

last_read_pos = 0
while not job.done():
time.sleep(0.5)
std.seek(last_read_pos)
last = std.read()
if match := re.search(r"(\d+\.\d+)%", last):
last_read_pos = std.tell()
progress.update(task, completed=float(match.group(1)))
last_read_pos = 0
while not job.done():
time.sleep(0.5)
std.seek(last_read_pos)
last = std.read()
if match := re.search(r"(\d+\.\d+)%", last):
last_read_pos = std.tell()
progress.update(task, completed=float(match.group(1)), refresh=True)

try:
return job.result()
Expand All @@ -46,41 +50,49 @@ def topaz_batch(
train=False,
tile_size=32,
patch_size=32,
gpus=None,
dry_run=False,
verbose=False,
overwrite=False,
):
set_num_threads(0)
if train:
task = progress.add_task(description="Training...")
model, _ = _run_and_update_progress(
progress,
task,
train_model,
even_path=even,
odd_path=odd,
save_prefix=str(outdir / "trained_models" / model_name),
save_interval=10,
device=-2,
tilesize=patch_size,
base_kernel_width=11,
num_workers=multiprocessing.cpu_count(),
)
else:
model = load_model(model_name, base_kernel_width=11)
model.eval()
model, use_cuda, num_devices = set_device(model, -2)

inputs = [ts["recon"] for ts in tilt_series]

if verbose:
if train:
print(f"training model: '{model_name}' with inputs '{even}' and '{odd}'")
if len(inputs) > 2:
print(f"denoising: [{inputs[0]} [...] {inputs[-1]}]")
else:
print(f"denoising: {inputs}")
print(f"output: {outdir}")

if not dry_run:
set_num_threads(0)
std = io.StringIO()
if train:
task = progress.add_task(description="Training...")
model, _ = _run_and_update_progress(
progress,
task,
train_model,
even_path=even,
odd_path=odd,
save_prefix=str(outdir / "trained_models" / model_name),
save_interval=10,
device=-2
if gpus is None
else next(iter(gpus)), # TODO: actually run with multiple gpus?
tilesize=patch_size,
base_kernel_width=11,
num_workers=multiprocessing.cpu_count(),
)
else:
with contextlib.redirect_stdout(std):
model = load_model(model_name, base_kernel_width=11)
with contextlib.redirect_stdout(std):
model.eval()
model, use_cuda, num_devices = set_device(model, -2)

for path in progress.track(inputs, description="Denoising..."):
subtask = progress.add_task(description=path.name)
_run_and_update_progress(
Expand Down
9 changes: 6 additions & 3 deletions src/waretomo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,10 @@ def cli(
"verbose": verbose,
}

topaz_opts = {
topaz_kwargs = {
"train": train,
"model_name": topaz_model,
"gpus": gpus,
"tile_size": topaz_tile_size,
"patch_size": topaz_patch_size,
}
Expand Down Expand Up @@ -289,7 +290,9 @@ def cli(
aretomo_opts_log = "".join(
f'{nl}{" " * 12}- {k}: {v}' for k, v in aretomo_kwargs.items()
)
topaz_log = "".join(f'{nl}{" " * 12}- {k}: {v}' for k, v in topaz_opts.items())
topaz_log = "".join(
f'{nl}{" " * 12}- {k}: {v}' for k, v in topaz_kwargs.items()
)

print(
Panel(
Expand Down Expand Up @@ -402,6 +405,6 @@ def cli(
outdir=outdir_denoised,
even=str(output_dir / "even"),
odd=str(output_dir / "odd"),
**topaz_opts,
**topaz_kwargs,
**meta_kwargs,
)

0 comments on commit 687678b

Please sign in to comment.