Commit

topaz working?
brisvag committed Aug 25, 2023
1 parent f522f60 commit 366be72
Showing 2 changed files with 85 additions and 48 deletions.
109 changes: 65 additions & 44 deletions src/waretomo/_topaz.py
@@ -1,5 +1,6 @@
import contextlib
import io
import multiprocessing
import re
import time
from concurrent import futures
@@ -8,34 +9,74 @@
import torch
import torch.nn as nn
from rich import print
from topaz.commands.denoise3d import denoise
from topaz.commands.denoise3d import denoise, train_model
from topaz.denoise import UDenoiseNet3D
from topaz.torch import set_num_threads


def topaz_batch_train(progress):
    pass
def _run_and_update_progress(progress, task, func, *args, **kwargs):
    std = io.StringIO()
    with futures.ThreadPoolExecutor(1) as executor:
        with contextlib.redirect_stderr(std), contextlib.redirect_stdout(std):
            job = executor.submit(func, *args, **kwargs)

            last_read_pos = 0
            while not job.done():
                time.sleep(0.5)
                std.seek(last_read_pos)
                last = std.read()
                if match := re.search(r"(\d+\.\d+)%", last):
                    last_read_pos = std.tell()
                    progress.update(task, completed=float(match.group(1)))

    try:
        return job.result()
    except RuntimeError as e:
        if "CUDA out of memory." in e.args[0]:
            raise RuntimeError(
                "Not enough GPU memory. "
                "Try to lower --topaz-tile-size or --topaz-patch-size"
            ) from e
        raise


def topaz_batch(
    progress,
    tilt_series,
    outdir,
    even,
    odd,
    train=False,
    tile_size=32,
    patch_size=32,
    dry_run=False,
    verbose=False,
    overwrite=False,
):
    set_num_threads(0)
    model = UDenoiseNet3D(base_width=7)
    f = pkg_resources.resource_stream(
        "topaz", "pretrained/denoise/unet-3d-10a-v0.2.4.sav"
    )
    state_dict = torch.load(f)
    model.load_state_dict(state_dict)
    model = nn.DataParallel(model)
    model.cuda()
    if train:
        task = progress.add_task(description="Training...")
        model, _ = _run_and_update_progress(
            progress,
            task,
            train_model,
            even_path=even,
            odd_path=odd,
            save_prefix=str(outdir / "trained_model"),
            save_interval=10,
            device=-2,
            tilesize=tile_size,
            num_workers=multiprocessing.cpu_count(),
        )
    else:
        model = UDenoiseNet3D(base_width=7)
        f = pkg_resources.resource_stream(
            "topaz", "pretrained/denoise/unet-3d-10a-v0.2.4.sav"
        )
        state_dict = torch.load(f)
        model.load_state_dict(state_dict)
        model = nn.DataParallel(model)
        model.cuda()

    inputs = [ts["recon"] for ts in tilt_series]

@@ -49,36 +90,16 @@ def topaz_batch(
    if not dry_run:
        for path in progress.track(inputs, description="Denoising..."):
            subtask = progress.add_task(description=path.name)
            stderr = io.StringIO()
            with contextlib.redirect_stderr(stderr):
                with futures.ThreadPoolExecutor(1) as executor:
                    job = executor.submit(
                        denoise,
                        model=model,
                        path=path,
                        outdir=str(outdir),
                        batch_size=torch.cuda.device_count(),
                        patch_size=patch_size,
                        padding=patch_size // 2,
                        suffix="",
                    )

                    last_read_pos = 0
                    while not job.done():
                        time.sleep(0.5)
                        stderr.seek(last_read_pos)
                        last = stderr.read()
                        if match := re.search(r"(\d+.\d+)%", last):
                            last_read_pos = stderr.tell()
                            progress.update(subtask, completed=float(match.group(1)))

            progress.update(subtask, visible=False)

            try:
                job.result()
            except RuntimeError as e:
                if "CUDA out of memory." in e.args[0]:
                    raise RuntimeError(
                        "Not enough GPU memory. Try a lower --topaz-patch-size"
                    ) from e
                raise
            _run_and_update_progress(
                progress,
                subtask,
                denoise,
                model=model,
                path=path,
                outdir=str(outdir),
                batch_size=torch.cuda.device_count(),
                patch_size=patch_size,
                padding=patch_size // 2,
                suffix="",
            )
            progress.update(subtask, visible=False)
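
Note on the new helper above: _run_and_update_progress submits a blocking topaz call to a single-thread executor, captures everything the call prints to stdout/stderr in a StringIO, and periodically scrapes the latest "NN.N%" string out of that buffer to drive the rich progress bar. Below is a minimal, self-contained sketch of that pattern; slow_job is a hypothetical stand-in for topaz's denoise/train_model, and the Console(stderr=True) is only there so the bar itself is not swallowed by the redirect.

import contextlib
import io
import re
import time
from concurrent import futures

from rich.console import Console
from rich.progress import Progress


def slow_job(n_steps=20):
    # hypothetical worker: prints percentages the way topaz reports progress
    for i in range(n_steps):
        time.sleep(0.1)
        print(f"{100 * (i + 1) / n_steps:.1f}%", flush=True)
    return "done"


# pin the progress bar to the real terminal (stderr) so the stdout redirect
# below does not swallow its output
with Progress(console=Console(stderr=True)) as progress:
    task = progress.add_task("Working...", total=100)
    std = io.StringIO()
    with futures.ThreadPoolExecutor(1) as executor:
        with contextlib.redirect_stdout(std):
            job = executor.submit(slow_job)

            read_pos = 0
            while not job.done():
                time.sleep(0.5)
                # scrape the newest "NN.N%" from whatever the worker printed
                new_text = std.getvalue()[read_pos:]
                if match := re.search(r"(\d+\.\d+)%", new_text):
                    read_pos += match.end()
                    progress.update(task, completed=float(match.group(1)))

    print(job.result())  # re-raises here if the worker failed
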
24 changes: 20 additions & 4 deletions src/waretomo/main.py
@@ -104,6 +104,12 @@ def __str__(self):
@click.option(
    "--train", is_flag=True, default=False, help="whether to train a new denoising model"
)
@click.option(
    "--topaz-tile-size",
    type=int,
    default=64,
    help="tile size for training the topaz model.",
)
@click.option(
    "--topaz-patch-size",
    type=int,
@@ -149,6 +155,7 @@ def cli(
    roi_dir,
    overwrite,
    train,
    topaz_tile_size,
    topaz_patch_size,
    start_from,
    stop_at,
@@ -326,7 +333,8 @@ def cli(
        for half in ("even", "odd"):
            if verbose:
                print(f"\n[green]Reconstructing {half} tomograms for denoising...")
            (output_dir / half).mkdir(parents=True, exist_ok=True)
            half_dir = output_dir / half
            half_dir.mkdir(parents=True, exist_ok=True)
            aretomo_batch(
                progress,
                tilt_series,
@@ -336,19 +344,27 @@
                **aretomo_kwargs,
                **meta_kwargs,
            )
            # remove leftovers from aretomo, otherwise topaz dies later
            for f in half_dir.glob("*_projX?.mrc"):
                f.unlink(missing_ok=True)
            for f in half_dir.glob("*.aretomolog"):
                f.unlink(missing_ok=True)

    if steps["denoise"]:
        from ._topaz import topaz_batch

        if verbose:
            print("\n[green]Denoising tomograms...")
        outdir = output_dir / "denoised"
        outdir.mkdir(parents=True, exist_ok=True)
        outdir_denoised = output_dir / "denoised"
        outdir_denoised.mkdir(parents=True, exist_ok=True)
        topaz_batch(
            progress,
            tilt_series,
            outdir=outdir,
            outdir=outdir_denoised,
            even=str(output_dir / "even"),
            odd=str(output_dir / "odd"),
            train=train,
            tile_size=topaz_tile_size,
            patch_size=topaz_patch_size,
            **meta_kwargs,
        )
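
As a quick, hypothetical smoke test of the new flag wiring (assuming the package is installed and the click entry point is waretomo.main.cli; real paths and required arguments are omitted, so only --help is exercised):

from click.testing import CliRunner

from waretomo.main import cli

runner = CliRunner()
# --help is eager, so this only checks that the new options are registered
result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0
assert "--topaz-tile-size" in result.output
assert "--topaz-patch-size" in result.output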
