From 1712647dfa79e3b2fdc94923f168ca44e24ea5f0 Mon Sep 17 00:00:00 2001 From: dkazanc Date: Mon, 7 Oct 2024 15:53:06 +0100 Subject: [PATCH] distortion correction preview refactoring --- httomolibgpu/prep/alignment.py | 36 +++++++++++++++---------------- tests/test_prep/test_alignment.py | 5 +++-- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/httomolibgpu/prep/alignment.py b/httomolibgpu/prep/alignment.py index 85be8cd4..2a5b6004 100644 --- a/httomolibgpu/prep/alignment.py +++ b/httomolibgpu/prep/alignment.py @@ -33,7 +33,7 @@ else: map_coordinates = Mock() -from typing import Dict, List +from typing import Dict, List, Tuple __all__ = [ "distortion_correction_proj_discorpy", @@ -48,7 +48,8 @@ def distortion_correction_proj_discorpy( data: cp.ndarray, metadata_path: str, - preview: Dict[str, List[int]], + shift: Tuple[int, int] = (0, 0), + step: Tuple[int, int] = (1, 1), order: int = 1, mode: str = "reflect", ): @@ -63,11 +64,11 @@ def distortion_correction_proj_discorpy( The path to the file containing the distortion coefficients for the data. - preview : Dict[str, List[int]] - A dict containing three key-value pairs: - - a list containing the `start` value of each dimension - - a list containing the `stop` value of each dimension - - a list containing the `step` value of each dimension + shift: tuple, optional + Centers of distortion in x (from the left of the image) and y directions (from the top of the image). + + step: tuple, optional + Steps in x and y directions respectively. They need to be not larger than one. order : int, optional. The order of the spline interpolation. @@ -90,12 +91,10 @@ def distortion_correction_proj_discorpy( # Use preview information to offset the x and y coords of the center of # distortion - shift = preview["starts"] - step = preview["steps"] - x_dim = 1 - y_dim = 0 - step_check = max([step[i] for i in [x_dim, y_dim]]) > 1 - if step_check: + det_x_step = step[0] + det_y_step = step[1] + + if det_y_step > 1 or det_x_step > 1: msg = ( "\n***********************************************\n" "!!! ERROR !!! -> Method doesn't work with the step in" @@ -104,12 +103,13 @@ def distortion_correction_proj_discorpy( ) raise ValueError(msg) - x_offset = shift[x_dim] - y_offset = shift[y_dim] - xcenter = xcenter - x_offset - ycenter = ycenter - y_offset + det_x_shift = shift[0] + det_y_shift = shift[1] + + xcenter = xcenter - det_x_shift + ycenter = ycenter - det_y_shift - height, width = data.shape[y_dim + 1], data.shape[x_dim + 1] + height, width = data.shape[1], data.shape[2] xu_list = cp.arange(width) - xcenter yu_list = cp.arange(height) - ycenter xu_mat, yu_mat = cp.meshgrid(xu_list, yu_list) diff --git a/tests/test_prep/test_alignment.py b/tests/test_prep/test_alignment.py index c4b484cf..3922f927 100644 --- a/tests/test_prep/test_alignment.py +++ b/tests/test_prep/test_alignment.py @@ -40,8 +40,9 @@ def test_correct_distortion( im_host = imread(path) im = cp.asarray(im_host) - preview = {"starts": [0, 0], "stops": [im.shape[0], im.shape[1]], "steps": [1, 1]} - corrected_data = implementation(im, distortion_coeffs_path, preview).get() + shift = (0, 0) + step = (1, 1) + corrected_data = implementation(im, distortion_coeffs_path, shift, step).get() assert_allclose(np.mean(corrected_data), mean_value) assert np.max(corrected_data) == max_value