Commit

add code
brisvag committed Aug 23, 2023
1 parent bde64c7 commit c54dcf7
Showing 7 changed files with 538 additions and 0 deletions.
172 changes: 172 additions & 0 deletions src/waretomo/_aretomo.py
@@ -0,0 +1,172 @@
import os
import shutil
import subprocess
from contextlib import contextmanager
from pathlib import Path
from queue import Queue
from time import sleep

import GPUtil
from rich import print

from .threaded import run_threaded


@contextmanager
def _cd(dir):
    # temporarily change the working directory, restoring it even if the body raises
    prev = os.getcwd()
    os.chdir(dir)
    try:
        yield
    finally:
        os.chdir(prev)

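# Note: on Python >= 3.11 the standard library provides the same behaviour as
# contextlib.chdir; a minimal equivalent usage sketch (assuming Python 3.11+):
#
#     from contextlib import chdir
#     with chdir(some_dir):
#         ...  # cwd is some_dir here, restored on exit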

def _aretomo(
input_,
rawtlt,
aln,
xf,
output,
full_ts_name,
suffix="",
cmd="AreTomo",
tilt_axis=0,
patches=None,
roi_file=None,
tilt_corr=True,
thickness_align=1200,
thickness_recon=0,
binning=4,
px_size=0,
kv=0,
dose=0,
cs=0,
defocus=0,
reconstruct=False,
gpu_queue=None,
dry_run=False,
verbose=False,
overwrite=False,
):
    # cwd dance is necessary because aretomo messes up paths otherwise
    # need to use os.path.relpath because pathlib cannot handle non-subpath relative paths
# https://stackoverflow.com/questions/38083555/using-pathlibs-relative-to-for-directories-on-the-same-level
cwd = output.parent.absolute()
input_ = Path(os.path.relpath(input_, cwd))
rawtlt = Path(os.path.relpath(rawtlt, cwd))
aln = Path(os.path.relpath(aln, cwd))
xf = Path(os.path.relpath(xf, cwd))
output = Path(os.path.relpath(output, cwd))
if not reconstruct:
output = output.with_stem(output.stem + "_aligned").with_suffix(".st")
# LogFile is broken, so we do it ourselves
log = output.with_suffix(".aretomolog")
with _cd(cwd):
if not overwrite and output.exists():
raise FileExistsError(output)

# only one job per gpu
gpu = 0 if gpu_queue is None else gpu_queue.get()

options = {
"InMrc": input_,
"OutMrc": output,
# 'LogFile': input_.with_suffix('.log').relative_to(cwd), # currently broken
"OutBin": binning,
"Gpu": gpu,
"DarkTol": 0,
}

if reconstruct:
options.update(
{
"AlnFile": aln,
"VolZ": thickness_recon,
"PixSize": px_size,
"Kv": kv,
"ImgDose": dose,
"Cs": cs,
"Defoc": defocus,
"FlipVol": 1,
"WBP": 1,
}
)
else:
options.update(
{
"AngFile": rawtlt,
"AlignZ": thickness_align,
"TiltCor": int(tilt_corr),
"VolZ": 0,
"TiltAxis": f"{tilt_axis} 1" if tilt_axis is not None else "0 1",
"OutImod": 2,
}
)
if roi_file is not None:
options["RoiFile"] = roi_file
if patches is not None:
options["Patch"] = f"{patches} {patches}"

# run aretomo with basic settings
aretomo_cmd = f"{cmd} {' '.join(f'-{k} {v}' for k, v in options.items())}"

if verbose:
print(aretomo_cmd)
if not reconstruct:
print(f'mv {xf} {full_ts_name + ".xf"}')

    if not dry_run:
        with _cd(cwd):
            proc = None
            try:
                proc = subprocess.run(
                    aretomo_cmd.split(), capture_output=True, check=False, cwd=cwd
                )
            finally:
                # write our own log (AreTomo's -LogFile is broken) and always
                # return the gpu to the queue, even if the subprocess call failed
                if proc is not None:
                    log.write_bytes(proc.stdout + proc.stderr)
                if gpu_queue is not None:
                    gpu_queue.put(gpu)
            proc.check_returncode()
if not reconstruct:
# move xf file so warp can see it (needs full ts name + .xf)
shutil.move(xf, full_ts_name + ".xf")
else:
sleep(0.1)
if gpu_queue is not None:
gpu_queue.put(gpu)

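# For illustration only, a sketch of the command string the options dict above
# produces for an alignment (non-reconstruction) run; file names and the tilt
# axis value here are hypothetical:
#
#     options = {"InMrc": "TS_01.st", "OutMrc": "TS_01_aligned.st", "OutBin": 4,
#                "Gpu": 0, "DarkTol": 0, "AngFile": "TS_01.rawtlt", "AlignZ": 1200,
#                "TiltCor": 1, "VolZ": 0, "TiltAxis": "85.0 1", "OutImod": 2}
#     "AreTomo " + " ".join(f"-{k} {v}" for k, v in options.items())
#     # -> 'AreTomo -InMrc TS_01.st -OutMrc TS_01_aligned.st -OutBin 4 -Gpu 0 ...'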

def aretomo_batch(
progress, tilt_series, suffix="", label="", cmd="AreTomo", gpus=None, **kwargs
):
if not shutil.which(cmd):
raise FileNotFoundError(f"{cmd} is not available on the system")
if gpus is None:
gpus = [gpu.id for gpu in GPUtil.getGPUs()]
if not gpus:
raise RuntimeError("you need at least one GPU to run AreTomo")
if kwargs.get("verbose"):
print(f"[yellow]Running AreTomo in parallel on {len(gpus)} GPUs.")

# use a queue to hold gpu ids to ensure we only run one job per gpu
gpu_queue = Queue()
for gpu in gpus:
gpu_queue.put(gpu)

partials = []
for ts in tilt_series:
partials.append(
lambda ts=ts: _aretomo(
input_=ts["stack" + suffix] if suffix else ts["fix"],
rawtlt=ts["rawtlt"],
aln=ts["aln"],
xf=ts["xf"],
roi_file=ts["roi"],
output=ts["recon" + suffix],
gpu_queue=gpu_queue,
cmd=cmd,
full_ts_name=ts["name"],
**ts["aretomo_kwargs"],
**kwargs,
)
)

run_threaded(progress, partials, label=label, max_workers=len(gpus), **kwargs)
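
The run_threaded helper imported at the top is not shown in this excerpt; note that the `lambda ts=ts` default argument is what binds each tilt series to its own job at definition time. A minimal sketch of a compatible run_threaded, assuming a ThreadPoolExecutor and a rich progress bar (the signature and internals here are assumptions, not the actual module):

from concurrent.futures import ThreadPoolExecutor, as_completed


def run_threaded(progress, partials, label="", max_workers=None, **kwargs):
    # submit every partial and advance a progress task as each job finishes
    task = progress.add_task(label, total=len(partials))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(fn) for fn in partials]
        for future in as_completed(futures):
            future.result()  # re-raise any exception from the worker thread
            progress.update(task, advance=1)
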
37 changes: 37 additions & 0 deletions src/waretomo/_fix.py
@@ -0,0 +1,37 @@
import shutil
import subprocess
from time import sleep

from rich import print

from .threaded import run_threaded


def _ccderaser(
input, output, cmd="ccderaser", dry_run=False, verbose=False, overwrite=False
):
if not overwrite and output.exists():
raise FileExistsError(output)
# run ccderaser, defaults from etomo
ccderaser_cmd = (
f"{cmd} -input {input} -output {output} -find -peak 8.0 "
f"-diff 6.0 -big 19. -giant 12. -large 8. -grow 4. -edge 4"
)

if verbose:
print(ccderaser_cmd)
if not dry_run:
subprocess.run(ccderaser_cmd.split(), capture_output=True, check=True)
else:
sleep(0.1)


def fix_batch(progress, tilt_series, cmd="ccderaser", **kwargs):
if not shutil.which(cmd):
raise FileNotFoundError(f"{cmd} is not available on the system")

partials = [
lambda ts=ts: _ccderaser(ts["stack"], ts["fix"], cmd=cmd, **kwargs)
for ts in tilt_series
]
run_threaded(progress, partials, label="Fixing", **kwargs)
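
A hedged usage sketch of fix_batch; the tilt-series paths are hypothetical and the real caller lives elsewhere in waretomo, so this only illustrates the expected dict keys and keyword arguments:

from pathlib import Path
from rich.progress import Progress

tilt_series = [  # hypothetical entries; "stack" is the input, "fix" the output
    {"stack": Path("TS_01.st"), "fix": Path("TS_01_fix.st")},
]
with Progress() as progress:
    fix_batch(progress, tilt_series, dry_run=True, verbose=True, overwrite=False)
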
55 changes: 55 additions & 0 deletions src/waretomo/_fix_mdoc.py
@@ -0,0 +1,55 @@
import pandas as pd
from mdocfile.mdoc import Mdoc

from .threaded import run_threaded


def _tilt_mdoc(
mdoc_file, tlt_file, skipped_tilts, verbose=False, dry_run=False, overwrite=False
):
if verbose:
print(f"Tilting mdoc: {mdoc_file}")
print(f"using: {tlt_file}")

output = mdoc_file.parent / "mdoc_tilted" / mdoc_file.name

if not overwrite and output.exists():
raise FileExistsError(output)

if not dry_run:
mdoc = Mdoc.from_file(mdoc_file)
new_angles = iter(pd.read_csv(tlt_file, header=None)[0])
for i, section in enumerate(mdoc.section_data):
if i in skipped_tilts:
continue
try:
section.TiltAngle = next(new_angles)
except StopIteration as e:
raise RuntimeError(
f"not enough tilts generated by aretomo in {tlt_file}"
) from e

try:
next(new_angles)
except StopIteration:
pass
else:
raise RuntimeError(f"too many tilts generated by aretomo in {tlt_file}")

with open(output, "w+" if overwrite else "w") as f:
f.write(mdoc.to_string())


def tilt_mdocs_batch(progress, tilt_series, **kwargs):
partials = []
for ts in tilt_series:
partials.append(
lambda ts=ts: _tilt_mdoc(
mdoc_file=ts["mdoc"],
tlt_file=ts["tlt"],
skipped_tilts=ts["skipped_tilts"],
**kwargs,
)
)

run_threaded(progress, partials, label="Creating tilted mdocs...", **kwargs)
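
A small worked example of the bookkeeping above, with hypothetical values: sections that Warp deselected keep their original mdoc angle, while the remaining sections receive, in order, the refined angles AreTomo wrote to the .tlt file:

mdoc_angles = [-60.0, -57.0, -54.0, -51.0, -48.0]  # TiltAngle per mdoc section
skipped_tilts = [0, 4]                             # indices deselected in Warp
aretomo_angles = iter([-56.8, -53.9, -51.1])       # refined angles from the .tlt

for i in range(len(mdoc_angles)):
    if i in skipped_tilts:
        continue  # skipped sections keep their original angle
    mdoc_angles[i] = next(aretomo_angles)

print(mdoc_angles)  # [-60.0, -56.8, -53.9, -51.1, -48.0]
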
128 changes: 128 additions & 0 deletions src/waretomo/_parse.py
@@ -0,0 +1,128 @@
from pathlib import Path, PureWindowsPath
from xml.etree import ElementTree

import mdocfile


def parse_data(
progress, warp_dir, mdoc_dir, output_dir, roi_dir, just=(), exclude=(), train=False
):
imod_dir = warp_dir / "imod"
if not imod_dir.exists():
raise FileNotFoundError("warp directory does not have an `imod` subdirectory")

if just:
mdocs = [
p
for ts_name in just
if (p := (Path(mdoc_dir) / (ts_name + ".mdoc"))).exists()
]
else:
mdocs = sorted(Path(mdoc_dir).glob("*.mdoc"))

if not mdocs:
raise FileNotFoundError("could not find any mdoc files")

odd_dir = warp_dir / "average" / "odd"
even_dir = warp_dir / "average" / "even"

tilt_series = []
tilt_series_excluded = []
tilt_series_unprocessed = []

for mdoc in progress.track(mdocs, description="Reading mdocs..."):
ts_name = mdoc.stem
stack = imod_dir / ts_name / (ts_name + ".st")

if ts_name in exclude:
tilt_series_excluded.append(ts_name)
continue
# skip if not preprocessed in warp
if not stack.exists():
tilt_series_unprocessed.append(ts_name)
continue

df = mdocfile.read(mdoc)

# extract even/odd paths
tilts = [warp_dir / PureWindowsPath(tilt).name for tilt in df.SubFramePath]
skipped_tilts = []
odd = []
even = []
valid_xml = None
for i, tilt in enumerate(tilts):
xml = ElementTree.parse(tilt.with_suffix(".xml")).getroot()
if xml.attrib["UnselectManual"] == "True":
skipped_tilts.append(i)
else:
valid_xml = xml
odd.append(odd_dir / (tilt.stem + ".mrc"))
even.append(even_dir / (tilt.stem + ".mrc"))

if valid_xml is None:
tilt_series_unprocessed.append(ts_name)
continue

if train:
for img in odd + even:
if not img.exists():
raise FileNotFoundError(img)

# extract metadata from warp xmls
        # (we assume the last valid xml has the same data as the others)
for param in valid_xml.find("OptionsCTF"):
if param.get("Name") == "BinTimes":
binning = float(param.get("Value"))
elif param.get("Name") == "Voltage":
kv = int(param.get("Value"))
elif param.get("Name") == "Cs":
cs = float(param.get("Value"))
        # use the same valid xml as above for the CTF metadata
        for param in valid_xml.find("CTF"):
if param.get("Name") == "Defocus":
defocus = (
float(param.get("Value")) * 1e4
) # defocus for aretomo is in Angstrom

if roi_dir is not None:
roi_files = list(roi_dir.glob(f"{ts_name}*"))
if len(roi_files) == 1:
roi_file = roi_files[0]
else:
roi_file = None
else:
roi_file = None

ts_fixed = ts_name + "_fix"
ts_stripped = ts_name.split(".")[0]
alignment_result_dir = output_dir / (ts_stripped + "_Imod")

tilt_series.append(
{
"name": ts_name,
"stack": stack,
"rawtlt": stack.with_suffix(".rawtlt"),
"fix": output_dir / (ts_fixed + ".st"),
"aln": output_dir / (ts_stripped + ".aln"),
"xf": alignment_result_dir / (ts_stripped + ".xf"),
"tlt": alignment_result_dir / (ts_stripped + ".tlt"),
"skipped_tilts": skipped_tilts,
"mdoc": mdoc,
"roi": roi_file,
"odd": odd,
"even": even,
"stack_odd": output_dir / (ts_name + "_odd.st"),
"stack_even": output_dir / (ts_name + "_even.st"),
"recon_odd": output_dir / "odd" / (ts_name + ".mrc"),
"recon_even": output_dir / "even" / (ts_name + ".mrc"),
"recon": output_dir / (ts_name + ".mrc"),
"aretomo_kwargs": {
"dose": df.ExposureDose[0],
"px_size": df.PixelSpacing[0] * 2**binning,
"cs": cs,
"kv": kv,
"defocus": defocus,
},
}
)

return tilt_series, tilt_series_excluded, tilt_series_unprocessed
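
A short worked example of the unit handling in the aretomo_kwargs block, with hypothetical values; per the code above, Warp's BinTimes is treated as a log2 binning factor applied to the mdoc PixelSpacing, and the Warp defocus is scaled by 1e4 (implying it is stored in micrometres) to give Angstrom for AreTomo:

bin_times = 1.0        # Warp "BinTimes" (log2 of the binning factor)
pixel_spacing = 1.7    # mdoc "PixelSpacing", Angstrom per unbinned pixel
warp_defocus = 3.2     # Warp "Defocus" value, converted below as in the code

px_size = pixel_spacing * 2**bin_times  # 3.4, passed to AreTomo as -PixSize
defocus = warp_defocus * 1e4            # 32000.0, Angstrom for AreTomo -Defoc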