diff --git a/src/processing_steps/0300_volume_matcher.py b/src/processing_steps/0300_volume_matcher.py index a9c7d6a..3dd4c9f 100755 --- a/src/processing_steps/0300_volume_matcher.py +++ b/src/processing_steps/0300_volume_matcher.py @@ -1,3 +1,7 @@ +#! /usr/bin/python3 +''' +Script for matching volumes from the top and bottom of a multi-scan tomogram. +''' import sys sys.path.append(sys.path[0]+"/../") from config.paths import hdf5_root @@ -16,25 +20,26 @@ # - should probably be removed by kernel-operation around outer edge, before matching... ########################### -### Fix shifted volumes ### -########################### + ''' + Find shift that minimizes squared differences with `overlap <= shift <= max_shift`. + + Parameters + ---------- + `voxels_top` : jp.array[Any] + The top region to match. + `voxels_bot` : jp.array[Any] + The bottom region to match. + `overlap` : int + The overlap between the regions. + `max_shift` : int + The maximum shift to consider. + + Returns + ------- + `result` : Tuple[int, float] + The shift that minimizes the squared differences and the corresponding error. + ''' -import h5py, sys, jax, os.path, pathlib, tqdm -import numpy as np -import jax.numpy as jp -import h5py, jax, sys -from PIL import Image -sys.path.append(sys.path[0]+"/../") -from config.paths import hdf5_root -from lib.py.helpers import commandline_args - -verbose = 1 -volume_matched_dir = f"{hdf5_root}/processed/volume_matched" - -def match_region(voxels_top, voxels_bot, overlap, max_shift): - """ - Find shift that minimizes squared differences with overlap <= shift <= max_shift - """ # Shifts smaller than the overlap overlap with shift slice_size = voxels_top.shape[1]*voxels_top.shape[2] # Normalize by number of voxels (to make sums_lt and sums_ge comparable) sums_lt = jp.array( [ jp.sum(((voxels_top[-shift:] - voxels_bot[0:shift])/(shift*slice_size))**2) @@ -53,6 +58,24 @@ def match_region(voxels_top, voxels_bot, overlap, max_shift): def match_all_regions(voxels,crossings,write_image_checks=True): + ''' + Match all regions in a volume. + + Parameters + ---------- + `voxels` : np.array[Any] + The volume to match. + `crossings` : np.array[int] + The crossings between the regions. + `write_image_checks` : bool + Whether to write images to check correctness. + + Returns + ------- + `shifts, errors` : Tuple[np.array[int], np.array[float]] + The shifts that minimize the squared differences and the corresponding errors. + ''' + shifts = np.zeros(len(crossings),dtype=np.int32) errors = np.zeros(len(crossings),dtype=np.float32) match_region_jit = jax.jit(match_region,static_argnums=(2,3)); @@ -91,9 +114,27 @@ def match_all_regions(voxels,crossings,write_image_checks=True): del bot_voxels return shifts,errors -# Copy through the volume matched volume from the original - general interface -# that works equally well for anything indexed like numpy arrays - including cupy, HDF5 and netCDF. def write_matched(voxels_in, voxels_out, crossings, shifts): + ''' + Copy through the volume matched volume from the original. + - general interface that works equally well for anything indexed like numpy arrays - including cupy, HDF5 and netCDF. + + Parameters + ---------- + `voxels_in` : np.array[Any] + The original volume. + `voxels_out` : np.array[Any] + The volume to write to. + `crossings` : np.array[int] + The crossings between the regions. + `shifts` : np.array[int] + The shifts that minimize the squared differences. + + Returns + ------- + `None` + ''' + voxels_out[:crossings[0]] = voxels_in[:crossings[0]] cum_shifts = [0]+list(np.cumsum(shifts)) crossings = list(crossings) + [voxels_in.shape[0]] @@ -105,9 +146,28 @@ def write_matched(voxels_in, voxels_out, crossings, shifts): if verbose >= 1: print(f"Duplicating unmatched part of subvolume {i+1}: voxels_out[{crossings[i]-cum_shifts[i]}:{crossings[i+1]-cum_shifts[i]-shifts[i]}] = voxels_in[{crossings[i]+shifts[i]}:{crossings[i+1]}];") voxels_out[crossings[i]-cum_shifts[i]:crossings[i+1]-cum_shifts[i]-shifts[i]] = voxels_in[crossings[i]+shifts[i]:crossings[i+1]]; - -# Create and populate a volume matched HDF5-file from the original def write_matched_hdf5(h5_filename_in, h5_filename_out, crossings, shifts, compression='lzf'): + ''' + Create and populate a volume matched HDF5-file from the original. + + Parameters + ---------- + `h5_filename_in` : str + The input HDF5 file. + `h5_filename_out` : str + The output HDF5 file. + `crossings` : np.array[int] + The crossings between the regions. + `shifts` : np.array[int] + The shifts that minimize the squared differences. + `compression` : str + The compression to use. + + Returns + ------- + `None` + ''' + h5in = h5py.File(h5_filename_in, "r") h5out = h5py.File(h5_filename_out,"w") voxels_in = h5in['voxels'];