Skip to content

Commit

Permalink
#34 Added docstrings for the functions in 0300
Browse files Browse the repository at this point in the history
  • Loading branch information
carljohnsen committed Sep 17, 2024
1 parent 7095fe1 commit baa6983
Showing 1 changed file with 82 additions and 22 deletions.
104 changes: 82 additions & 22 deletions src/processing_steps/0300_volume_matcher.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#! /usr/bin/python3
'''
Script for matching volumes from the top and bottom of a multi-scan tomogram.
'''
import sys
sys.path.append(sys.path[0]+"/../")
from config.paths import hdf5_root
Expand All @@ -16,25 +20,26 @@
# - should probably be removed by kernel-operation around outer edge, before matching...

###########################
### Fix shifted volumes ###
###########################
'''
Find shift that minimizes squared differences with `overlap <= shift <= max_shift`.
Parameters
----------
`voxels_top` : jp.array[Any]
The top region to match.
`voxels_bot` : jp.array[Any]
The bottom region to match.
`overlap` : int
The overlap between the regions.
`max_shift` : int
The maximum shift to consider.
Returns
-------
`result` : Tuple[int, float]
The shift that minimizes the squared differences and the corresponding error.
'''

import h5py, sys, jax, os.path, pathlib, tqdm
import numpy as np
import jax.numpy as jp
import h5py, jax, sys
from PIL import Image
sys.path.append(sys.path[0]+"/../")
from config.paths import hdf5_root
from lib.py.helpers import commandline_args

verbose = 1
volume_matched_dir = f"{hdf5_root}/processed/volume_matched"

def match_region(voxels_top, voxels_bot, overlap, max_shift):
"""
Find shift that minimizes squared differences with overlap <= shift <= max_shift
"""
# Shifts smaller than the overlap overlap with shift
slice_size = voxels_top.shape[1]*voxels_top.shape[2] # Normalize by number of voxels (to make sums_lt and sums_ge comparable)
sums_lt = jp.array( [ jp.sum(((voxels_top[-shift:] - voxels_bot[0:shift])/(shift*slice_size))**2)
Expand All @@ -53,6 +58,24 @@ def match_region(voxels_top, voxels_bot, overlap, max_shift):


def match_all_regions(voxels,crossings,write_image_checks=True):
'''
Match all regions in a volume.
Parameters
----------
`voxels` : np.array[Any]
The volume to match.
`crossings` : np.array[int]
The crossings between the regions.
`write_image_checks` : bool
Whether to write images to check correctness.
Returns
-------
`shifts, errors` : Tuple[np.array[int], np.array[float]]
The shifts that minimize the squared differences and the corresponding errors.
'''

shifts = np.zeros(len(crossings),dtype=np.int32)
errors = np.zeros(len(crossings),dtype=np.float32)
match_region_jit = jax.jit(match_region,static_argnums=(2,3));
Expand Down Expand Up @@ -91,9 +114,27 @@ def match_all_regions(voxels,crossings,write_image_checks=True):
del bot_voxels
return shifts,errors

# Copy through the volume matched volume from the original - general interface
# that works equally well for anything indexed like numpy arrays - including cupy, HDF5 and netCDF.
def write_matched(voxels_in, voxels_out, crossings, shifts):
'''
Copy through the volume matched volume from the original.
- general interface that works equally well for anything indexed like numpy arrays - including cupy, HDF5 and netCDF.
Parameters
----------
`voxels_in` : np.array[Any]
The original volume.
`voxels_out` : np.array[Any]
The volume to write to.
`crossings` : np.array[int]
The crossings between the regions.
`shifts` : np.array[int]
The shifts that minimize the squared differences.
Returns
-------
`None`
'''

voxels_out[:crossings[0]] = voxels_in[:crossings[0]]
cum_shifts = [0]+list(np.cumsum(shifts))
crossings = list(crossings) + [voxels_in.shape[0]]
Expand All @@ -105,9 +146,28 @@ def write_matched(voxels_in, voxels_out, crossings, shifts):
if verbose >= 1: print(f"Duplicating unmatched part of subvolume {i+1}: voxels_out[{crossings[i]-cum_shifts[i]}:{crossings[i+1]-cum_shifts[i]-shifts[i]}] = voxels_in[{crossings[i]+shifts[i]}:{crossings[i+1]}];")
voxels_out[crossings[i]-cum_shifts[i]:crossings[i+1]-cum_shifts[i]-shifts[i]] = voxels_in[crossings[i]+shifts[i]:crossings[i+1]];


# Create and populate a volume matched HDF5-file from the original
def write_matched_hdf5(h5_filename_in, h5_filename_out, crossings, shifts, compression='lzf'):
'''
Create and populate a volume matched HDF5-file from the original.
Parameters
----------
`h5_filename_in` : str
The input HDF5 file.
`h5_filename_out` : str
The output HDF5 file.
`crossings` : np.array[int]
The crossings between the regions.
`shifts` : np.array[int]
The shifts that minimize the squared differences.
`compression` : str
The compression to use.
Returns
-------
`None`
'''

h5in = h5py.File(h5_filename_in, "r")
h5out = h5py.File(h5_filename_out,"w")
voxels_in = h5in['voxels'];
Expand Down

0 comments on commit baa6983

Please sign in to comment.