Skip to content

Commit

Permalink
FIX+REFACTOR: A bit of refactoring and a fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
brudfors committed Aug 6, 2021
1 parent 68eb379 commit d552ba3
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 18 deletions.
37 changes: 21 additions & 16 deletions unires/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
from nitorch.core.utils import ceil_pow
from nitorch.tools._preproc_fov import _bb_atlas
from nitorch.tools._preproc_utils import _mean_space
from torch._C import device
# UniRes
from ._project import _proj_info
from .struct import (_input, _output)
from ._util import (_print_info, _read_image, _write_image, _read_label)


def _all_mat_dim_vx(x, sett):
""" Get all images affine matrices, dimensions and voxel sizes (as numpy arrays).
""" Get all images affine matrices, dimensions and voxel sizes.
Returns:
all_mat (torch.tensor): Image orientation matrices (N, 4, 4).
Expand Down Expand Up @@ -207,10 +208,12 @@ def _format_y(x, sett):
if vx_y is None and ((N == 1) or vx_same): # One image, voxel size not given
vx_y = all_vx[0, ...]

do_pow = (isinstance(sett.pow, (tuple, list)) and len(sett.pow) == 3) \
or (isinstance(sett.pow, int) and sett.pow > 0)
if vx_same and (torch.abs(all_vx[0, ...] - vx_y) < 1e-3).all():
# All input images have same voxel size, and output voxel size is the also the same
do_sr = False
if mat_same and dim_same and not sett.unified_rigid:
if mat_same and dim_same and not sett.unified_rigid and not sett.crop and not do_pow:
# All input images have the same FOV
mat = all_mat[0, ...]
dim = all_dim[0, ...]
Expand All @@ -231,13 +234,15 @@ def _format_y(x, sett):
mat = mat_mu.mm(mat_vx)
dim = mat_vx[:3, :3].inverse().mm(dim[:, None]).floor().squeeze()

if sett.pow:
# Ensure output image dimensions are compatible with encode/decode
# architecture
dim2 = ceil_pow(dim, p=2.0, l=2.0, mx=256)
dim3 = ceil_pow(dim, p=2.0, l=3.0, mx=256)
ndim = dim2
ndim[dim3 < ndim] = dim3[dim3 < ndim]
if do_pow:
# crops output FOV to fixed dimensions
if isinstance(sett.pow, int):
dim2 = ceil_pow(dim, p=2.0, l=2.0, mx=sett.pow)
dim3 = ceil_pow(dim, p=2.0, l=3.0, mx=sett.pow)
ndim = dim2
ndim[dim3 < ndim] = dim3[dim3 < ndim]
else:
ndim = torch.as_tensor(sett.pow, device=sett.device)
# Modulate output affine
mat_bb = affine_matrix_classic(-((ndim - dim)/2).round())\
.type(torch.float64).to(sett.device)
Expand Down Expand Up @@ -318,7 +323,7 @@ def _init_reg(x, sett):
if sett.do_coreg and N > 1:
# Align images, pairwise, to fixed image (fix)
t0 = _print_info('init-reg', sett, 'co', 'begin', N)
mat_a = affine_align(imgs, fix=fix, device=sett.device)[1]
mat_a = affine_align(imgs, **sett.coreg_params, fix=fix, device=sett.device)[1]
# Apply coreg transform
i = 0
for c in range(len(x)):
Expand Down Expand Up @@ -380,9 +385,9 @@ def _init_y_dat(x, y, sett):
bound='zero', extrapolate=False, interpolation=1)
dat[dat < mn] = mn
dat[dat > mx] = mx
sm = sm + (dat[0, 0, ...].round() != 0)
sm = sm + (dat[0, 0, ...] > 0)
dat_y = dat_y + dat[0, 0, ...]
sm[sm == 0] = 1
sm[sm == 0] = 1.0
y[c].dat = dat_y / sm

return y
Expand Down Expand Up @@ -458,7 +463,7 @@ def _resample_inplane(x, sett):
# make grid
D = I.clone()
for i in range(3):
D[i, i] = sett.vx / vx_x[i]
D[i, i] = sett.vx[i] / vx_x[i]
if D[i, i] < 1.0:
D[i, i] = 1
if float((I - D).abs().sum()) < 1e-4:
Expand Down Expand Up @@ -615,7 +620,7 @@ def _write_data(x, y, sett, jtv=None):
if sett.write_out and sett.mat is None:
# Write reconstructed images (as separate niftis, because given as separate niftis)
if x[c][0].nam is None:
nam = str(c) + '.nii'
nam = str(c) + '.nii.gz'
else:
nam = x[c][0].nam
fname = os.path.join(dir_out, prefix_y + nam)
Expand All @@ -636,7 +641,7 @@ def _write_data(x, y, sett, jtv=None):
# Write reconstructed images as 4D volume (because given as 4D volume)
c = 0
if x[c][0].nam is None:
nam = str(c) + '.nii'
nam = str(c) + '.nii.gz'
else:
nam = x[c][0].nam
fname = os.path.join(dir_out, prefix_y + nam)
Expand All @@ -646,7 +651,7 @@ def _write_data(x, y, sett, jtv=None):
if sett.write_jtv and jtv is not None:
# Write JTV
if x[c][0].nam is None:
nam = str(c) + '.nii'
nam = str(c) + '.nii.gz'
else:
nam = x[c][0].nam
fname = os.path.join(dir_out, 'jtv_' + prefix_y + nam)
Expand Down
2 changes: 1 addition & 1 deletion unires/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def init(data, sett=settings()):
# Makes recons aligned with same grid, across subjects
sett.do_atlas_align = True
sett.crop = True
sett.pow = True
sett.pow = 256

# Read and format data (images and labels)
x = _read_data(data, sett)
Expand Down
3 changes: 2 additions & 1 deletion unires/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self):
self.cgs_tol: float = 1e-3 # CG tolerance for solving for y
self.cgs_verbose: bool = False # CG verbosity (0, 1)
self.clean_fov: bool = False # Set voxels outside of low-res FOV, projected in high-res space, to zero
self.coreg_params = {'cost_fun': 'nmi', 'group': 'SE', 'samp': (1), 'fwhm': 7, 'mean_space': False} # parameters for coregistration
self.crop: bool = False # Crop input images' FOV to brain in the NITorch atlas
self.common_output: bool = False # Makes recons aligned with same grid, across subjects
self.ct: bool = False # Data could be CT (if contain negative values)
Expand All @@ -88,7 +89,7 @@ def __init__(self):
self.max_iter: int = 512 # Max algorithm iterations
self.method = None # Method name (super-resolution|denoising), defined in format_output()
self.plot_conv: bool = False # Use matplotlib to plot convergence in real-time
self.pow: bool = False # Ensure output image dimensions are compatible with encode/decode architecture
self.pow: int = 0 # Ensure output image dimensions are a power of two or three, with max dimensions pow
self.prefix: str = 'ur_' # Prefix for reconstructed image(s)
self.profile_ip: int = 2 # In-plane slice profile (0=rect|1=tri|2=gauss)
self.profile_tp: int = 0 # Through-plane slice profile (0=rect|1=tri|2=gauss)
Expand Down

0 comments on commit d552ba3

Please sign in to comment.