From a55bce2130d011adcb15b0cccac94e8793d5310e Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 27 Jun 2024 16:28:33 +0200 Subject: [PATCH] Add kt sampling, allow for dynamic sampling --- direct/common/subsample.py | 1662 +++++++++++++++++++---- direct/common/subsample_config.py | 17 +- direct/data/mri_transforms.py | 5 +- direct/types.py | 12 + direct/utils/__init__.py | 56 + tests/test_train.py | 3 +- tests/tests_common/test_subsample.py | 309 +++-- tests/tests_data/test_mri_transforms.py | 15 +- 8 files changed, 1667 insertions(+), 412 deletions(-) diff --git a/direct/common/subsample.py b/direct/common/subsample.py index 058a36567..bd75eb559 100644 --- a/direct/common/subsample.py +++ b/direct/common/subsample.py @@ -19,18 +19,50 @@ import numpy as np import torch +from scipy.linalg import toeplitz +from scipy.ndimage import rotate import direct.data.transforms as T from direct.common._gaussian import gaussian_mask_1d, gaussian_mask_2d # pylint: disable=no-name-in-module from direct.common._poisson import poisson as _poisson # pylint: disable=no-name-in-module from direct.environment import DIRECT_CACHE_DIR -from direct.types import DirectEnum, Number -from direct.utils import str_to_class +from direct.types import DirectEnum, MaskFuncMode, Number +from direct.utils import reshape_array_to_shape, str_to_class from direct.utils.io import download_url # pylint: disable=arguments-differ +__all__ = [ + "BaseMaskFunc", + "CIRCUSMaskFunc", + "CIRCUSSamplingMode", + "CalgaryCampinasMaskFunc", + "CartesianEquispacedMaskFunc", + "CartesianMagicMaskFunc", + "CartesianRandomMaskFunc", + "CartesianVerticalMaskFunc", + "EquispacedMaskFunc", + "FastMRIEquispacedMaskFunc", + "FastMRIMagicMaskFunc", + "FastMRIRandomMaskFunc", + "Gaussian1DMaskFunc", + "Gaussian2DMaskFunc", + "KtBaseMaskFunc", + "KtGaussian1DMaskFunc", + "KtRadialMaskFunc", + "KtUniformMaskFunc", + "MagicMaskFunc", + "MaskFuncMode", + "RadialMaskFunc", + "RandomMaskFunc", + "SpiralMaskFunc", + "VariableDensityPoissonMaskFunc", + "build_masking_function", + "centered_disk_mask", + "integerize_seed", +] + logger = logging.getLogger(__name__) GOLDEN_RATIO = (1 + np.sqrt(5)) / 2 @@ -59,6 +91,13 @@ class BaseMaskFunc: is True, then two values should be given. Default: None. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( @@ -66,6 +105,7 @@ def __init__( accelerations: Union[list[Number], tuple[Number, ...]], center_fractions: Optional[Union[list[float], tuple[float, ...]]] = None, uniform_range: bool = True, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: """Inits :class:`BaseMaskFunc`. @@ -80,6 +120,13 @@ def __init__( is True, then two values should be given. Default: None. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ if center_fractions is not None: if len([center_fractions]) != len([accelerations]): @@ -92,6 +139,7 @@ def __init__( self.accelerations = accelerations self.uniform_range = uniform_range + self.mode = mode self.rng = np.random.RandomState() @@ -136,15 +184,30 @@ def mask_func(self, *args, **kwargs) -> torch.Tensor: """ raise NotImplementedError("This method should be implemented by a child class.") - def __call__(self, *args, **kwargs) -> torch.Tensor: + def __call__(self, shape: tuple[int, ...], *args, **kwargs) -> torch.Tensor: """Calls the mask function. + Parameters + ---------- + shape : tuple of ints + Shape of the mask to be created. Needs to be at least 3 dimensions. If mode is MaskFuncMode.DYNAMIC, + or MaskFuncMode.MULTISLICE, then the shape should have at least 4 dimensions. + args : tuple + Additional arguments to be passed to the mask function. + kwargs : dict + Additional keyword arguments to be passed to the mask function. + Returns ------- torch.Tensor Sampling mask. """ - mask = self.mask_func(*args, **kwargs) + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions.") + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] and len(shape) < 4: + raise ValueError("Shape should have 4 or more dimensions for dynamic or multislice mode.") + + mask = self.mask_func(shape, *args, **kwargs) return mask @@ -163,6 +226,13 @@ class CartesianVerticalMaskFunc(BaseMaskFunc): Fraction of low-frequency columns (float < 1.0) or number of low-frequence columns (integer) to be retained. uniform_range : bool, optional If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( @@ -170,6 +240,7 @@ def __init__( accelerations: Union[list[Number], tuple[Number, ...]], center_fractions: Union[list[float], tuple[float, ...]], uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: """Inits :class:`CartesianVerticalMaskFunc`. @@ -181,11 +252,19 @@ def __init__( Fraction of low-frequency columns (float < 1.0) or number of low-frequence columns (integer) to be retained. uniform_range : bool, optional If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ super().__init__( accelerations=accelerations, center_fractions=center_fractions, uniform_range=uniform_range, + mode=mode, ) @staticmethod @@ -211,8 +290,7 @@ def center_mask_func(num_cols: int, num_low_freqs: int) -> np.ndarray: return mask - @staticmethod - def _reshape_and_broadcast_mask(shape: tuple[int, ...], mask: np.ndarray) -> np.ndarray: + def _reshape_and_broadcast_mask(self, shape: tuple[int, ...], mask: np.ndarray) -> np.ndarray: """Broadcasts and reshapes the mask to the shape of the input k-space data. Parameters @@ -233,6 +311,8 @@ def _reshape_and_broadcast_mask(shape: tuple[int, ...], mask: np.ndarray) -> np. # Reshape the mask mask_shape = [1 for _ in shape] mask_shape[-2] = num_cols + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: + mask_shape[-4] = shape[-4] mask = mask.reshape(*mask_shape).astype(bool) mask_shape[-3] = num_rows @@ -242,14 +322,15 @@ def _reshape_and_broadcast_mask(shape: tuple[int, ...], mask: np.ndarray) -> np. return mask -class FastMRIRandomMaskFunc(CartesianVerticalMaskFunc): +class RandomMaskFunc(CartesianVerticalMaskFunc): r"""Random vertical line mask function. The mask selects a subset of columns from the input k-space data. If the k-space data has :math:`N` columns, the mask picks out: - #. :math:`N_{\text{low freqs}} = (N \times \text{center_fraction})` columns in the center corresponding - to low-frequencies. + #. :math:`N_{\text{low freqs}} = (N \times \text{center_fraction})` columns in the center corresponding + to low-frequencies if center_fraction < 1.0, or :math:`N_{\text{low freqs}} = \text{center_fraction}` + if center_fraction >= 1 and is integer. #. The other columns are selected uniformly at random with a probability equal to: :math:`\text{prob} = (N / \text{acceleration} - N_{\text{low freqs}}) / (N - N_{\text{low freqs}})`. This ensures that the expected number of columns selected is equal to (N / acceleration). @@ -266,33 +347,51 @@ class FastMRIRandomMaskFunc(CartesianVerticalMaskFunc): ---------- accelerations : Union[list[Number], tuple[Number, ...]] Amount of under-sampling. - center_fractions : Union[list[float], tuple[float, ...]] - Fraction of low-frequency columns (float < 1.0) to be retained. + center_fractions : Union[list[Number], tuple[Number, ...]] + If < 1.0 this corresponds to the fraction of low-frequency columns to be retained. + If >= 1 (integer) this corresponds to the exact number of low-frequency columns to be retained. uniform_range : bool, optional If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( self, accelerations: Union[list[Number], tuple[Number, ...]], - center_fractions: Union[list[float], tuple[float, ...]], + center_fractions: Union[list[Number], tuple[Number, ...]], uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: - """Inits :class:`FastMRIRandomMaskFunc`. + """Inits :class:`RandomMaskFunc`. Parameters ---------- accelerations : Union[list[Number], tuple[Number, ...]] Amount of under-sampling. - center_fractions : Union[list[float], tuple[float, ...]] - Fraction of low-frequency columns (float < 1.0) to be retained. + center_fractions : Union[list[Number], tuple[Number, ...]] + If < 1.0 this corresponds to the fraction of low-frequency columns to be retained. + If >= 1 (integer) this corresponds to the exact number of low-frequency columns to be retained. uniform_range : bool, optional If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ super().__init__( accelerations=accelerations, center_fractions=center_fractions, uniform_range=uniform_range, + mode=mode, ) def mask_func( @@ -320,28 +419,111 @@ def mask_func( mask : torch.Tensor The sampling mask. """ - if len(shape) < 3: - raise ValueError("Shape should have 3 or more dimensions") + num_cols = shape[-2] + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): - num_cols = shape[-2] center_fraction, acceleration = self.choose_acceleration() - num_low_freqs = int(round(num_cols * center_fraction)) + + if center_fraction < 1.0: + num_low_freqs = int(round(num_cols * center_fraction)) + else: + num_low_freqs = int(center_fraction) mask = self.center_mask_func(num_cols, num_low_freqs) + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: + mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0) + if return_acs: return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) # Create the mask - prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) - mask = mask | (self.rng.uniform(size=num_cols) < prob) + mask = mask.reshape(num_slc_or_time, -1) # In case mode != MaskFuncMode.STATIC: + for i in range(num_slc_or_time): + prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) + mask[i] = mask[i] | (self.rng.uniform(size=num_cols) < prob) return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) -class CartesianRandomMaskFunc(CartesianVerticalMaskFunc): +class FastMRIRandomMaskFunc(RandomMaskFunc): + r"""Random vertical line mask function. + + The mask selects a subset of columns from the input k-space data. If the k-space data has :math:`N` columns, + the mask picks out: + + #. :math:`N_{\text{low freqs}} = (N \times \text{center_fraction})` columns in the center corresponding + to low-frequencies. + #. The other columns are selected uniformly at random with a probability equal to: + :math:`\text{prob} = (N / \text{acceleration} - N_{\text{low freqs}}) / (N - N_{\text{low freqs}})`. + This ensures that the expected number of columns selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which case one possible + (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is + called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there + is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% + probability that 8-fold acceleration with 4% center fraction is selected. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction of low-frequency columns (float < 1.0) to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. + """ + + def __init__( + self, + accelerations: Union[list[Number], tuple[Number, ...]], + center_fractions: Union[list[float], tuple[float, ...]], + uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, + ) -> None: + """Inits :class:`FastMRIRandomMaskFunc`. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction of low-frequency columns (float < 1.0) to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. + """ + if not all(0 < center_fraction < 1 for center_fraction in center_fractions): + raise ValueError( + f"Center fraction values should be between 0 and 1. Received {center_fractions}. " + f"For exact number of center lines, use `CartesianMagicMaskFunc`." + ) + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + mode=mode, + ) + + +class CartesianRandomMaskFunc(RandomMaskFunc): r"""Cartesian random vertical line mask function. Similar to :class:`FastMRIRandomMaskFunc`, but instead of center fraction (`center_fractions`) representing @@ -356,6 +538,13 @@ class CartesianRandomMaskFunc(CartesianVerticalMaskFunc): Number of low-frequence (center) columns to be retained. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( @@ -363,6 +552,7 @@ def __init__( accelerations: Union[list[Number], tuple[Number, ...]], center_fractions: Union[list[int], tuple[int, ...]], uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: """Inits :class:`CartesianRandomMaskFunc`. @@ -375,11 +565,96 @@ def __init__( Number of low-frequence (center) columns to be retained. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. + """ + if not all((1 < center_fraction) and isinstance(center_fraction, int) for center_fraction in center_fractions): + raise ValueError( + f"Center fraction values should be integers greater then or equal to 1 corresponding to the number of " + f"center lines. Received {center_fractions}. For fractions, use `FastMRIMagicMaskFunc`." + ) + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + mode=mode, + ) + + +class EquispacedMaskFunc(CartesianVerticalMaskFunc): + r"""Equispaced vertical line mask function. + + :class:`EquispacedMaskFunc` creates a sub-sampling mask of given shape. The mask selects a subset of columns + from the input k-space data. If the k-space data has N columns, the mask picks out: + + #. :math:`N_{\text{low freqs}} = (N \times \text{center_fraction})` columns in the center corresponding + to low-frequencies if center_fraction < 1.0, or :math:`N_{\text{low freqs}} = \text{center_fraction}` + if center_fraction >= 1 and is integer. + #. The other columns are selected with equal spacing at a proportion that reaches the desired acceleration + rate taking into consideration the number of low frequencies. This ensures that the expected number of + columns selected is equal to :math:`\frac{N}{\text{acceleration}}`. + + It is possible to use multiple center_fractions and accelerations, in which case one possible + (center_fraction, acceleration) is chosen uniformly at random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require modifications to standard GRAPPA + approaches. Nonetheless, this aspect of the function has been preserved to match the public multicoil data. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[Number], tuple[Number, ...]] + If < 1.0 this corresponds to the fraction of low-frequency columns to be retained. + If >= 1 (integer) this corresponds to the exact number of low-frequency columns to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. + """ + + def __init__( + self, + accelerations: Union[list[Number], tuple[Number, ...]], + center_fractions: Union[list[Number], tuple[Number, ...]], + uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, + ) -> None: + """Inits :class:`EquispacedMaskFunc`. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[Number], tuple[Number, ...]] + If < 1.0 this corresponds to the fraction of low-frequency columns to be retained. + If >= 1 (integer) this corresponds to the exact number of low-frequency columns to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ super().__init__( accelerations=accelerations, center_fractions=center_fractions, uniform_range=uniform_range, + mode=mode, ) def mask_func( @@ -388,7 +663,7 @@ def mask_func( return_acs: bool = False, seed: Optional[Union[int, Iterable[int]]] = None, ) -> torch.Tensor: - """Creates a random vertical Cartesian mask. + """Creates an vertical equispaced vertical line mask. Parameters ---------- @@ -401,33 +676,45 @@ def mask_func( Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. Default: None. - Returns ------- mask : torch.Tensor The sampling mask. """ - if len(shape) < 3: - raise ValueError("Shape should have 3 or more dimensions") + num_cols = shape[-2] + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): - num_cols = shape[-2] - num_center_lines, acceleration = self.choose_acceleration() + center_fraction, acceleration = self.choose_acceleration() + + if center_fraction < 1.0: + num_low_freqs = int(round(num_cols * center_fraction)) + else: + num_low_freqs = int(center_fraction) + + mask = self.center_mask_func(num_cols, num_low_freqs) - mask = self.center_mask_func(num_cols, num_center_lines) + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: + mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0) if return_acs: return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) - # Create the mask - prob = (num_cols / acceleration - num_center_lines) / (num_cols - num_center_lines) - mask = mask | (self.rng.uniform(size=num_cols) < prob) + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols) + + mask = mask.reshape(num_slc_or_time, -1) # In case mode != MaskFuncMode.STATIC: + for i in range(num_slc_or_time): + offset = self.rng.randint(0, round(adjusted_accel)) + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[i, accel_samples] = True return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) -class FastMRIEquispacedMaskFunc(CartesianVerticalMaskFunc): +class FastMRIEquispacedMaskFunc(EquispacedMaskFunc): r"""Equispaced vertical line mask function. :class:`FastMRIEquispacedMaskFunc` creates a sub-sampling mask of given shape. The mask selects a subset of columns @@ -454,6 +741,13 @@ class FastMRIEquispacedMaskFunc(CartesianVerticalMaskFunc): Fraction (< 1.0) of low-frequency columns to be retained. uniform_range : bool, optional If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( @@ -461,6 +755,7 @@ def __init__( accelerations: Union[list[Number], tuple[Number, ...]], center_fractions: Union[list[float], tuple[float, ...]], uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: """Inits :class:`FastMRIEquispacedMaskFunc`. @@ -472,63 +767,28 @@ def __init__( Fraction (< 1.0) of low-frequency columns to be retained. uniform_range : bool, optional If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ + if not all(0 < center_fraction < 1 for center_fraction in center_fractions): + raise ValueError( + f"Center fraction values should be between 0 and 1. Received {center_fractions}. " + f"For exact number of center lines, use `CartesianMagicMaskFunc`." + ) super().__init__( accelerations=accelerations, center_fractions=center_fractions, uniform_range=uniform_range, + mode=mode, ) - def mask_func( - self, - shape: Union[list[int], tuple[int, ...]], - return_acs: bool = False, - seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: - """Creates an vertical equispaced vertical line mask. - - Parameters - ---------- - shape : list or tuple of ints - The shape of the mask to be created. The shape should at least 3 dimensions. - Samples are drawn along the second last dimension. - return_acs : bool - Return the autocalibration signal region as a mask. - seed : int or iterable of ints or None (optional) - Seed for the random number generator. Setting the seed ensures the same mask is generated - each time for the same shape. Default: None. - - Returns - ------- - mask : torch.Tensor - The sampling mask. - """ - if len(shape) < 3: - raise ValueError("Shape should have 3 or more dimensions") - - with temp_seed(self.rng, seed): - num_cols = shape[-2] - - center_fraction, acceleration = self.choose_acceleration() - num_low_freqs = int(round(num_cols * center_fraction)) - - mask = self.center_mask_func(num_cols, num_low_freqs) - - if return_acs: - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) - - # determine acceleration rate by adjusting for the number of low frequencies - adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols) - offset = self.rng.randint(0, round(adjusted_accel)) - - accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) - accel_samples = np.around(accel_samples).astype(np.uint) - mask[accel_samples] = True - - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) - -class CartesianEquispacedMaskFunc(CartesianVerticalMaskFunc): +class CartesianEquispacedMaskFunc(EquispacedMaskFunc): r"""Cartesian equispaced vertical line mask function. Similar to :class:`FastMRIEquispacedMaskFunc`, but instead of center fraction (`center_fractions`) representing @@ -543,6 +803,13 @@ class CartesianEquispacedMaskFunc(CartesianVerticalMaskFunc): Number of low-frequence (center) columns to be retained. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( @@ -550,6 +817,7 @@ def __init__( accelerations: Union[list[Number], tuple[Number, ...]], center_fractions: Union[list[int], tuple[int, ...]], uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: """Inits :class:`CartesianEquispacedMaskFunc`. @@ -562,71 +830,33 @@ def __init__( Number of low-frequence (center) columns to be retained. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ + if not all((1 < center_fraction) and isinstance(center_fraction, int) for center_fraction in center_fractions): + raise ValueError( + f"Center fraction values should be integers greater then or equal to 1 corresponding to the number of " + f"center lines. Received {center_fractions}. For fractions, use `FastMRIMagicMaskFunc`." + ) super().__init__( accelerations=accelerations, center_fractions=center_fractions, uniform_range=uniform_range, + mode=mode, ) - def mask_func( - self, - shape: Union[list[int], tuple[int, ...]], - return_acs: bool = False, - seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: - """Creates an equispaced vertical Cartesian mask. - - Parameters - ---------- - shape : list or tuple of ints - The shape of the mask to be created. The shape should at least 3 dimensions. - Samples are drawn along the second last dimension. - return_acs : bool - Return the autocalibration signal region as a mask. - seed : int or iterable of ints or None (optional) - Seed for the random number generator. Setting the seed ensures the same mask is generated - each time for the same shape. Default: None. - - - Returns - ------- - mask : torch.Tensor - The sampling mask. - """ - if len(shape) < 3: - raise ValueError("Shape should have 3 or more dimensions") - - with temp_seed(self.rng, seed): - num_cols = shape[-2] - - num_center_lines, acceleration = self.choose_acceleration() - - num_center_lines = int(num_center_lines) - mask = self.center_mask_func(num_cols, num_center_lines) - if return_acs: - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) - - # determine acceleration rate by adjusting for the number of low frequencies - adjusted_accel = (acceleration * (num_center_lines - num_cols)) / ( - num_center_lines * acceleration - num_cols - ) - offset = self.rng.randint(0, round(adjusted_accel)) - - accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) - accel_samples = np.around(accel_samples).astype(np.uint) - mask[accel_samples] = True - - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) - - -class FastMRIMagicMaskFunc(CartesianVerticalMaskFunc): +class MagicMaskFunc(CartesianVerticalMaskFunc): """Vertical line mask function as implemented in [1]_. - :class:`FastMRIMagicMaskFunc` exploits the conjugate symmetry via offset-sampling. It is essentially an + :class:`MagicMaskFunc` exploits the conjugate symmetry via offset-sampling. It is essentially an equispaced mask with an offset for the opposite site of the k-space. Since MRI images often exhibit approximate - conjugate k-space symmetry, this mask is generally more efficient than :class:`FastMRIEquispacedMaskFunc`. + conjugate k-space symmetry, this mask is generally more efficient than :class:`EquispacedMaskFunc`. References ---------- @@ -638,34 +868,52 @@ class FastMRIMagicMaskFunc(CartesianVerticalMaskFunc): accelerations : Union[list[Number], tuple[Number, ...]] Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by mask_type. Has to be the same length as center_fractions if uniform_range is not True. - center_fractions : Union[list[float], tuple[float, ...]] - Fraction (< 1.0) of low-frequency columns to be retained. + center_fractions : Union[list[Number], tuple[Number, ...]] + If < 1.0 this corresponds to the fraction of low-frequency columns to be retained. + If >= 1 (integer) this corresponds to the exact number of low-frequency columns to be retained. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( self, accelerations: Union[list[Number], tuple[Number, ...]], - center_fractions: Union[list[float], tuple[float, ...]], + center_fractions: Union[list[Number], tuple[Number, ...]], uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: - """Inits :class:`FastMRIMagicMaskFunc`. + """Inits :class:`MagicMaskFunc`. Parameters ---------- accelerations : Union[list[Number], tuple[Number, ...]] Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by mask_type. Has to be the same length as center_fractions if uniform_range is not True. - center_fractions : Union[list[float], tuple[float, ...]] - Fraction (< 1.0) of low-frequency columns to be retained. + center_fractions : Union[list[Number], tuple[Number, ...]] + If < 1.0 this corresponds to the fraction of low-frequency columns to be retained. + If >= 1.0 this corresponds to the exact number of low-frequency columns to be retained. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ super().__init__( accelerations=accelerations, center_fractions=center_fractions, uniform_range=uniform_range, + mode=mode, ) def mask_func( @@ -692,21 +940,30 @@ def mask_func( mask : torch.Tensor The sampling mask. """ - if len(shape) < 3: - raise ValueError("Shape should have 3 or more dimensions") + num_cols = shape[-2] + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): - num_cols = shape[-2] center_fraction, acceleration = self.choose_acceleration() - num_low_freqs = int(round(num_cols * center_fraction)) + # This is essentially for CartesianMagicMaskFunc, indicating the excact number of low frequency lines + # to be retained. + if center_fraction > 1: + num_low_freqs = center_fraction + # Otherwise, if < 1, it is the fraction of low frequency lines to be retained, for FastMRIMagicMaskFunc. + else: + num_low_freqs = int(round(num_cols * center_fraction)) + # bound the number of low frequencies between 1 and target columns target_cols_to_sample = int(round(num_cols / acceleration)) num_low_freqs = max(min(num_low_freqs, target_cols_to_sample), 1) acs_mask = self.center_mask_func(num_cols, num_low_freqs) + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: + acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0) + if return_acs: return torch.from_numpy(self._reshape_and_broadcast_mask(shape, acs_mask)) @@ -716,145 +973,170 @@ def mask_func( if adjusted_target_cols_to_sample > 0: adjusted_acceleration = int(round(num_cols / adjusted_target_cols_to_sample)) - offset = self.rng.randint(0, high=adjusted_acceleration) + acs_mask = acs_mask.reshape(num_slc_or_time, -1) # In case mode != MaskFuncMode.STATIC: + mask = [] + for i in range(num_slc_or_time): + offset = self.rng.randint(0, high=adjusted_acceleration) - if offset % 2 == 0: - offset_pos = offset + 1 - offset_neg = offset + 2 - else: - offset_pos = offset - 1 + 3 - offset_neg = offset - 1 + 0 + if offset % 2 == 0: + offset_pos = offset + 1 + offset_neg = offset + 2 + else: + offset_pos = offset - 1 + 3 + offset_neg = offset - 1 + 0 - poslen = (num_cols + 1) // 2 - neglen = num_cols - (num_cols + 1) // 2 - mask_positive = np.zeros(poslen, dtype=bool) - mask_negative = np.zeros(neglen, dtype=bool) + poslen = (num_cols + 1) // 2 + neglen = num_cols - (num_cols + 1) // 2 + mask_positive = np.zeros(poslen, dtype=bool) + mask_negative = np.zeros(neglen, dtype=bool) - mask_positive[offset_pos::adjusted_acceleration] = True - mask_negative[offset_neg::adjusted_acceleration] = True - mask_negative = np.flip(mask_negative) + mask_positive[offset_pos::adjusted_acceleration] = True + mask_negative[offset_neg::adjusted_acceleration] = True + mask_negative = np.flip(mask_negative) - mask = np.fft.fftshift(np.concatenate((mask_positive, mask_negative))) - mask = np.logical_or(mask, acs_mask) + mask.append(np.fft.fftshift(np.concatenate((mask_positive, mask_negative)))) + mask[i] = np.logical_or(mask[i], acs_mask[i]) + mask = np.stack(mask, axis=0).squeeze() return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) -class CartesianMagicMaskFunc(CartesianVerticalMaskFunc): - r"""Cartesian equispaced mask function as implemented in [1]_. +class FastMRIMagicMaskFunc(MagicMaskFunc): + """Vertical line mask function as implemented in [1]_. - Similar to :class:`FastMRIMagicMaskFunc`, but instead of center fraction (`center_fractions`) representing - the fraction of center lines to the original size, here, it represents the actual number of center lines. + :class:`FastMRIMagicMaskFunc` exploits the conjugate symmetry via offset-sampling. It is essentially an + equispaced mask with an offset for the opposite site of the k-space. Since MRI images often exhibit approximate + conjugate k-space symmetry, this mask is generally more efficient than :class:`FastMRIEquispacedMaskFunc`. References ---------- .. [1] Defazio, Aaron. “Offset Sampling Improves Deep Learning Based Accelerated MRI Reconstructions by - Exploiting Symmetry.” ArXiv:1912.01101 [Cs, Eess], Feb. 2020. arXiv.org, http://arxiv.org/abs/1912.01101. + Exploiting Symmetry.” ArXiv:1912.01101 [Cs, Eess], Feb. 2020. arXiv.org, http://arxiv.org/abs/1912.01101. Parameters ---------- accelerations : Union[list[Number], tuple[Number, ...]] Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by mask_type. Has to be the same length as center_fractions if uniform_range is not True. - center_fractions : Union[list[int], tuple[int, ...]] - Number of low-frequence (center) columns to be retained. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction (< 1.0) of low-frequency columns to be retained. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( self, accelerations: Union[list[Number], tuple[Number, ...]], - center_fractions: Union[list[int], tuple[int, ...]], + center_fractions: Union[list[float], tuple[float, ...]], uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: - """Inits :class:`CartesianMagicMaskFunc`. + """Inits :class:`FastMRIMagicMaskFunc`. Parameters ---------- accelerations : Union[list[Number], tuple[Number, ...]] Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by mask_type. Has to be the same length as center_fractions if uniform_range is not True. - center_fractions : Union[list[int], tuple[int, ...]] - Number of low-frequence (center) columns to be retained. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction (< 1.0) of low-frequency columns to be retained. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ + if not all(0 < center_fraction < 1 for center_fraction in center_fractions): + raise ValueError( + f"Center fraction values should be between 0 and 1. Received {center_fractions}. " + f"For exact number of center lines, use `CartesianMagicMaskFunc`." + ) + super().__init__( accelerations=accelerations, center_fractions=center_fractions, uniform_range=uniform_range, + mode=mode, ) - def mask_func( - self, - shape: Union[list[int], tuple[int, ...]], - return_acs: bool = False, - seed: Optional[Union[int, Iterable[int]]] = None, - ) -> torch.Tensor: - r"""Creates an equispaced Cartesian mask that exploits conjugate symmetry. - - Parameters - ---------- - shape : list or tuple of ints - The shape of the mask to be created. The shape should at least 3 dimensions. - Samples are drawn along the second last dimension. - return_acs : bool - Return the autocalibration signal region as a mask. - seed : int or iterable of ints or None (optional) - Seed for the random number generator. Setting the seed ensures the same mask is generated - each time for the same shape. Default: None. - - Returns - ------- - mask : torch.Tensor - The sampling mask. - """ - if len(shape) < 3: - raise ValueError("Shape should have 3 or more dimensions") - - with temp_seed(self.rng, seed): - num_cols = shape[-2] - - num_center_lines, acceleration = self.choose_acceleration() - - # bound the number of low frequencies between 1 and target columns - target_cols_to_sample = int(round(num_cols / acceleration)) - num_center_lines = max(min(num_center_lines, target_cols_to_sample), 1) - acs_mask = self.center_mask_func(num_cols, num_center_lines) - - if return_acs: - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, acs_mask)) - - # adjust acceleration rate based on target acceleration. - adjusted_target_cols_to_sample = target_cols_to_sample - num_center_lines - adjusted_acceleration = 0 - if adjusted_target_cols_to_sample > 0: - adjusted_acceleration = int(round(num_cols / adjusted_target_cols_to_sample)) - - offset = self.rng.randint(0, high=adjusted_acceleration) +class CartesianMagicMaskFunc(MagicMaskFunc): + r"""Cartesian equispaced mask function as implemented in [1]_. - if offset % 2 == 0: - offset_pos = offset + 1 - offset_neg = offset + 2 - else: - offset_pos = offset - 1 + 3 - offset_neg = offset - 1 + 0 + Similar to :class:`FastMRIMagicMaskFunc`, but instead of center fraction (`center_fractions`) representing + the fraction of center lines to the original size, here, it represents the actual number of center lines. - poslen = (num_cols + 1) // 2 - neglen = num_cols - (num_cols + 1) // 2 - mask_positive = np.zeros(poslen, dtype=bool) - mask_negative = np.zeros(neglen, dtype=bool) + References + ---------- + .. [1] Defazio, Aaron. “Offset Sampling Improves Deep Learning Based Accelerated MRI Reconstructions by + Exploiting Symmetry.” ArXiv:1912.01101 [Cs, Eess], Feb. 2020. arXiv.org, http://arxiv.org/abs/1912.01101. - mask_positive[offset_pos::adjusted_acceleration] = True - mask_negative[offset_neg::adjusted_acceleration] = True - mask_negative = np.flip(mask_negative) + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by + mask_type. Has to be the same length as center_fractions if uniform_range is not True. + center_fractions : Union[list[int], tuple[int, ...]] + Number of low-frequence (center) columns to be retained. + uniform_range : bool + If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. + """ - mask = np.fft.fftshift(np.concatenate((mask_positive, mask_negative))) - mask = np.logical_or(mask, acs_mask) + def __init__( + self, + accelerations: Union[list[Number], tuple[Number, ...]], + center_fractions: Union[list[int], tuple[int, ...]], + uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, + ) -> None: + """Inits :class:`CartesianMagicMaskFunc`. - return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by + mask_type. Has to be the same length as center_fractions if uniform_range is not True. + center_fractions : Union[list[int], tuple[int, ...]] + Number of low-frequence (center) columns to be retained. + uniform_range : bool + If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. + """ + if not all((1 < center_fraction) and isinstance(center_fraction, int) for center_fraction in center_fractions): + raise ValueError( + f"Center fraction values should be integers greater then or equal to 1 corresponding to the number of " + f"center lines. Received {center_fractions}. For fractions, use `FastMRIMagicMaskFunc`." + ) + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + mode=mode, + ) class CalgaryCampinasMaskFunc(BaseMaskFunc): @@ -1043,13 +1325,23 @@ class CIRCUSMaskFunc(BaseMaskFunc): Parameters ---------- - accelerations : Union[list[Number], tuple[Number, ...]] - Amount of under-sampling. subsampling_scheme : CIRCUSSamplingMode The subsampling scheme to use. Can be either `CIRCUSSamplingMode.CIRCUS_RADIAL` or `CIRCUSSamplingMode.CIRCUS_SPIRAL`. + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]], optional + Fraction (< 1.0) of low-frequency samples to be retained. If None, it will calculate the acs mask based on the + maximum masked disk in the generated mask (with a tolerance).Default: None. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. References ---------- @@ -1061,21 +1353,33 @@ class CIRCUSMaskFunc(BaseMaskFunc): def __init__( self, - accelerations: Union[list[Number], tuple[Number, ...]], subsampling_scheme: CIRCUSSamplingMode, + accelerations: Union[list[Number], tuple[Number, ...]], + center_fractions: Optional[Union[list[float], tuple[float, ...]]] = None, uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: """Inits :class:`CIRCUSMaskFunc`. Parameters ---------- - accelerations : Union[list[Number], tuple[Number, ...]] - Amount of under-sampling. subsampling_scheme : CIRCUSSamplingMode The subsampling scheme to use. Can be either `CIRCUSSamplingMode.CIRCUS_RADIAL` or `CIRCUSSamplingMode.CIRCUS_SPIRAL`. + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]], optional + Fraction (< 1.0) of low-frequency samples to be retained. If None, it will calculate the acs mask + based on the maximum masked disk in the generated mask (with a tolerance).Default: None. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. Raises ------ @@ -1084,8 +1388,9 @@ def __init__( """ super().__init__( accelerations=accelerations, - center_fractions=tuple(0 for _ in range(len(accelerations))), + center_fractions=center_fractions if center_fractions else tuple(0 for _ in range(len(accelerations))), uniform_range=uniform_range, + mode=mode, ) if subsampling_scheme not in [CIRCUSSamplingMode.CIRCUS_RADIAL, CIRCUSSamplingMode.CIRCUS_SPIRAL]: raise NotImplementedError( @@ -1285,29 +1590,73 @@ def mask_func( The sampling mask. """ - if len(shape) < 3: - raise ValueError("Shape should have 3 or more dimensions") + num_rows = shape[-3] + num_cols = shape[-2] + + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): - num_rows = shape[-3] - num_cols = shape[-2] - acceleration = self.choose_acceleration()[1] - - if self.subsampling_scheme == "circus-radial": - mask = self.circus_radial_mask( - shape=(num_rows, num_cols), - acceleration=acceleration, - ) - elif self.subsampling_scheme == "circus-spiral": - mask = self.circus_spiral_mask( - shape=(num_rows, num_cols), - acceleration=acceleration, + center_fraction, acceleration = self.choose_acceleration() + + if center_fraction == 0: + acs_mask = [] + mask = [] + for _ in range(num_slc_or_time): + if self.subsampling_scheme == "circus-radial": + mask.append( + self.circus_radial_mask( + shape=(num_rows, num_cols), + acceleration=acceleration, + ) + ) + elif self.subsampling_scheme == "circus-spiral": + mask.append( + self.circus_spiral_mask( + shape=(num_rows, num_cols), + acceleration=acceleration, + ) + ) + acs_mask.append(self.circular_centered_mask(mask[-1])) + mask = torch.stack(mask, dim=0).squeeze() + acs_mask = torch.stack(acs_mask, dim=0).squeeze() + acs_mask = reshape_array_to_shape(acs_mask, shape)[None].bool() + if return_acs: + return acs_mask + + else: + acs_mask = centered_disk_mask((num_rows, num_cols), center_fraction) + num_low_freqs = acs_mask.sum() + adjusted_accel = (acceleration * (num_low_freqs - num_rows * num_cols)) / ( + num_low_freqs * acceleration - num_rows * num_cols ) - if return_acs: - return self.circular_centered_mask(mask).unsqueeze(0).unsqueeze(-1) + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: + acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0) + + acs_mask = torch.from_numpy(reshape_array_to_shape(acs_mask, shape)[np.newaxis]).bool() - return mask.unsqueeze(0).unsqueeze(-1) + if return_acs: + return acs_mask + + mask = [] + for _ in range(num_slc_or_time): + if self.subsampling_scheme == "circus-radial": + mask.append( + self.circus_radial_mask( + shape=(num_rows, num_cols), + acceleration=adjusted_accel, + ) + ) + elif self.subsampling_scheme == "circus-spiral": + mask.append( + self.circus_spiral_mask( + shape=(num_rows, num_cols), + acceleration=adjusted_accel, + ) + ) + + mask = torch.stack(mask, dim=0).squeeze() + return reshape_array_to_shape(mask, shape)[np.newaxis].bool() | acs_mask class RadialMaskFunc(CIRCUSMaskFunc): @@ -1317,14 +1666,26 @@ class RadialMaskFunc(CIRCUSMaskFunc): ---------- accelerations : Union[list[Number], tuple[Number, ...]] Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]], optional + Fraction (< 1.0) of low-frequency samples to be retained. If None, it will calculate the acs mask + based on the maximum masked disk in the generated mask (with a tolerance).Default: None. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( self, accelerations: Union[list[Number], tuple[Number, ...]], + center_fractions: Optional[Union[list[float], tuple[float, ...]]] = None, uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: """Inits :class:`RadialMaskFunc`. @@ -1332,12 +1693,25 @@ def __init__( ---------- accelerations : Union[list[Number], tuple[Number, ...]] Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]], optional + Fraction (< 1.0) of low-frequency samples to be retained. If None, it will calculate the acs mask + based on the maximum masked disk in the generated mask (with a tolerance).Default: None. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ super().__init__( accelerations=accelerations, + center_fractions=center_fractions, subsampling_scheme=CIRCUSSamplingMode.CIRCUS_RADIAL, + uniform_range=uniform_range, + mode=mode, ) @@ -1348,14 +1722,26 @@ class SpiralMaskFunc(CIRCUSMaskFunc): ---------- accelerations : Union[list[Number], tuple[Number, ...]] Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]], optional + Fraction (< 1.0) of low-frequency samples to be retained. If None, it will calculate the acs mask + based on the maximum masked disk in the generated mask (with a tolerance).Default: None. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( self, accelerations: Union[list[Number], tuple[Number, ...]], + center_fractions: Optional[Union[list[float], tuple[float, ...]]] = None, uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: """Inits :class:`SpiralMaskFunc`. @@ -1363,12 +1749,25 @@ def __init__( ---------- accelerations : Union[list[Number], tuple[Number, ...]] Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]], optional + Fraction (< 1.0) of low-frequency samples to be retained. If None, it will calculate the acs mask + based on the maximum masked disk in the generated mask (with a tolerance).Default: None. uniform_range : bool If True then an acceleration will be uniformly sampled between the two values. Default: False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ super().__init__( accelerations=accelerations, + center_fractions=center_fractions, subsampling_scheme=CIRCUSSamplingMode.CIRCUS_SPIRAL, + uniform_range=uniform_range, + mode=mode, ) @@ -1384,6 +1783,13 @@ class VariableDensityPoissonMaskFunc(BaseMaskFunc): For center_scale='r', then a centered disk area with radius equal to :math:`R = \sqrt{{n_r}^2 + {n_c}^2} \\times r` will be fully sampled, where :math:`n_r` and :math:`n_c` denote the input shape. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. crop_corner : bool, optional If True mask will be disk. Default: False. max_attempts : int, optional @@ -1415,6 +1821,7 @@ def __init__( self, accelerations: Union[list[Number], tuple[Number, ...]], center_fractions: Union[list[float], tuple[float, ...]], + mode: MaskFuncMode = MaskFuncMode.STATIC, crop_corner: Optional[bool] = False, max_attempts: Optional[int] = 10, tol: Optional[float] = 0.2, @@ -1431,6 +1838,13 @@ def __init__( For center_scale='r', then a centered disk area with radius equal to :math:`R = \sqrt{{n_r}^2 + {n_c}^2} \times r` will be fully sampled, where :math:`n_r` and :math:`n_c` denote the input shape. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. crop_corner : bool, optional If True mask will be disk. Default: False. max_attempts : int, optional @@ -1445,6 +1859,7 @@ def __init__( accelerations=accelerations, center_fractions=center_fractions, uniform_range=False, + mode=mode, ) self.crop_corner = crop_corner self.max_attempts = max_attempts @@ -1479,18 +1894,26 @@ def mask_func( mask : torch.Tensor The sampling mask of shape (1, shape[0], shape[1], 1). """ - if len(shape) < 3: - raise ValueError("Shape should have 3 or more dimensions") - num_rows, num_cols = shape[:2] + num_rows, num_cols = shape[-3:-1] + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): + self.rng.seed(integerize_seed(seed)) + center_fraction, acceleration = self.choose_acceleration() + if return_acs: - return torch.from_numpy( - centered_disk_mask((num_rows, num_cols), center_fraction)[np.newaxis, ..., np.newaxis] - ).bool() - mask = self.poisson(num_rows, num_cols, center_fraction, acceleration, integerize_seed(seed)) - return torch.from_numpy(mask[np.newaxis, ..., np.newaxis]).bool() + acs_mask = centered_disk_mask((num_rows, num_cols), center_fraction) + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: + acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0) + return torch.from_numpy(reshape_array_to_shape(acs_mask, shape)[np.newaxis]).bool() + + mask = [] + for _ in range(num_slc_or_time): + mask.append(self.poisson(num_rows, num_cols, center_fraction, acceleration, self.rng.randint(1e5))) + mask = np.stack(mask, axis=0).squeeze() + + return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]).bool() def poisson( self, @@ -1578,6 +2001,13 @@ class Gaussian1DMaskFunc(CartesianVerticalMaskFunc): Fraction of low-frequency columns (float < 1.0) to be retained. uniform_range : bool, optional If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( @@ -1585,6 +2015,7 @@ def __init__( accelerations: Union[list[Number], tuple[Number, ...]], center_fractions: Union[list[float], tuple[float, ...]], uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: """Inits :class:`Gaussian1DMaskFunc`. @@ -1596,11 +2027,19 @@ def __init__( Fraction of low-frequency columns (float < 1.0) to be retained. uniform_range : bool, optional If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ super().__init__( accelerations=accelerations, center_fractions=center_fractions, uniform_range=uniform_range, + mode=mode, ) def mask_func( @@ -1628,25 +2067,42 @@ def mask_func( mask : torch.Tensor The sampling mask. """ - if len(shape) < 3: - raise ValueError("Shape should have 3 or more dimensions") + + num_cols = shape[-2] + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): - num_cols = shape[-2] + self.rng.seed(integerize_seed(seed)) center_fraction, acceleration = self.choose_acceleration() num_low_freqs = int(round(num_cols * center_fraction)) mask = self.center_mask_func(num_cols, num_low_freqs).astype(int) + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: + mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0) + if return_acs: return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) # Calls cython function nonzero_count = int(np.round(num_cols / acceleration - num_low_freqs - 1)) - gaussian_mask_1d( - nonzero_count, num_cols, num_cols // 2, 6 * np.sqrt(num_cols // 2), mask, integerize_seed(seed) - ) + + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: + for i in range(num_slc_or_time): + gaussian_mask_1d( + nonzero_count, + num_cols, + num_cols // 2, + 6 * np.sqrt(num_cols // 2), + mask[i], + self.rng.randint(1e5), + ) + mask = mask.squeeze() + else: + gaussian_mask_1d( + nonzero_count, num_cols, num_cols // 2, 6 * np.sqrt(num_cols // 2), mask, self.rng.randint(1e5) + ) return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask).astype(bool)) @@ -1664,6 +2120,13 @@ class Gaussian2DMaskFunc(BaseMaskFunc): Fraction of low-frequency samples (float < 1.0) to be retained. uniform_range : bool, optional If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ def __init__( @@ -1671,6 +2134,7 @@ def __init__( accelerations: Union[list[Number], tuple[Number, ...]], center_fractions: Union[list[float], tuple[float, ...]], uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, ) -> None: """Inits :class:`Gaussian2DMaskFunc`. @@ -1682,11 +2146,19 @@ def __init__( Fraction of low-frequency samples (float < 1.0) to be retained. uniform_range : bool, optional If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. """ super().__init__( accelerations=accelerations, center_fractions=center_fractions, uniform_range=uniform_range, + mode=mode, ) def mask_func( @@ -1714,33 +2186,620 @@ def mask_func( mask : torch.Tensor The sampling mask. """ - if len(shape) < 3: - raise ValueError("Shape should have 3 or more dimensions") - - num_rows, num_cols = shape[:2] + num_rows, num_cols = shape[-3:-1] + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): + self.rng.seed(integerize_seed(seed)) + center_fraction, acceleration = self.choose_acceleration() mask = centered_disk_mask((num_rows, num_cols), center_fraction) + + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: + mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0) + if return_acs: - return torch.from_numpy(mask[np.newaxis, ..., np.newaxis]).bool() + return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]).bool() - # Calls cython function - nonzero_count = int(np.round(num_cols * num_rows / acceleration - mask.sum() - 1)) std = 6 * np.array([np.sqrt(num_rows // 2), np.sqrt(num_cols // 2)], dtype=float) - gaussian_mask_2d( - nonzero_count, num_rows, num_cols, num_rows // 2, num_cols // 2, std, mask, integerize_seed(seed) + + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: + for i in range(num_slc_or_time): + # Calls cython function + gaussian_mask_2d( + int(np.round(num_cols * num_rows / acceleration - mask[i].sum() - 1)), + num_rows, + num_cols, + num_rows // 2, + num_cols // 2, + std, + mask[i], + self.rng.randint(1e5), + ) + mask = mask.squeeze() + else: + nonzero_count = int(np.round(num_cols * num_rows / acceleration - mask.sum() - 1)) + # Calls cython function + gaussian_mask_2d( + nonzero_count, num_rows, num_cols, num_rows // 2, num_cols // 2, std, mask, self.rng.randint(1e5) + ) + + return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]).bool() + + +class KtBaseMaskFunc(BaseMaskFunc): + """Base class for kt mask functions. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction of low-frequency columns (float < 1.0) or number of low-frequence columns (integer) to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + """ + + def __init__( + self, + accelerations: Union[list[Number], tuple[Number, ...]], + center_fractions: Union[list[float], tuple[float, ...]], + uniform_range: bool = False, + ) -> None: + """Inits :class:`KtBaseMaskFunc`. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction of low-frequency columns (float < 1.0) or number of low-frequence columns (integer) to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + """ + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + mode=MaskFuncMode.DYNAMIC, + ) + + @staticmethod + def zero_pad_to_center(array: np.ndarray, target_shape: tuple[int, ...]) -> np.ndarray: + """Zero pads an array to the target shape around its center. + + Parameters + ---------- + array : ndarray + The input array. + target_shape : tuple of int + The target shape for each dimension. + + Returns + ------- + ndarray + The zero-padded array. + """ + current_shape = list(array.shape) + + # Extend current_shape if it has fewer dimensions than target_shape + if len(current_shape) < len(target_shape): + current_shape.extend([1] * (len(target_shape) - len(current_shape))) + + # If the shapes are already the same, return the original array + if all(current_shape[i] == target_shape[i] for i in range(len(target_shape))): + return array + + # Create an array of zeros with the target shape + padded_array = np.zeros(target_shape, dtype=array.dtype) + + # Calculate the slices for inserting the original array into the padded array + insert_slices = tuple( + slice((target_dim - current_dim) // 2, (target_dim - current_dim) // 2 + current_dim) + for target_dim, current_dim in zip(target_shape, current_shape) + ) + + # Insert the original array into the padded array + padded_array[insert_slices] = array + + return padded_array + + @staticmethod + def linear_indices_to_2d_coordinates(indices: np.ndarray, row_length: int) -> tuple[np.ndarray, np.ndarray]: + """Converts linear indices to 2D coordinates. + + Parameters + ---------- + indices : ndarray + The linear indices to convert. + row_length : int + The length of the rows in the 2D array. + + Returns + ------- + tuple of ndarray + The x and y coordinates. + """ + x_coords = indices - np.floor((indices - 1) / row_length) * row_length + y_coords = np.ceil(indices / row_length) + return x_coords.astype(int), y_coords.astype(int) + + @staticmethod + def find_nearest_empty_location(target_index: int, empty_indices: np.ndarray, row_length: int) -> int: + """Finds the nearest empty index to the target index in 2D space. + + Parameters + ---------- + target_index : int + The index of the target point. + empty_indices : ndarray + The indices of empty locations. + row_length : int + The length of the rows in the 2D array. + + Returns + ------- + int + The nearest empty index. + """ + x0, y0 = KtBaseMaskFunc.linear_indices_to_2d_coordinates(target_index, row_length) + x, y = KtBaseMaskFunc.linear_indices_to_2d_coordinates(empty_indices, row_length) + + distance_x = (x - x0) ** 2 + distance_y = (y - y0) ** 2 + distance_y = distance_y.astype(float) + distance_y[distance_y > np.finfo(float).eps] = np.inf # Preventing zero distance consideration + distance = np.sqrt(distance_x + distance_y) + + nearest_index = np.argmin(distance) + return empty_indices[nearest_index] + + @staticmethod + def resolve_duplicates_on_kt_grid( + phase: np.ndarray, time: np.ndarray, ny: int, nt: int + ) -> tuple[np.ndarray, np.ndarray]: + """Corrects overlapping trajectories in k-space by shifting points to the nearest vacant locations. + + Parameters + ---------- + phase : ndarray + The phase coordinates of the trajectories. + time : ndarray + The time coordinates of the trajectories. + ny : int + The number of phase encoding steps. + nt : int + The number of time encoding steps. + + Returns + ------- + tuple of ndarray + Corrected phase and time coordinates. + """ + phase_corrected = phase + np.ceil((ny + 1) / 2) + time_corrected = time + np.ceil((nt + 1) / 2) + trajectory_indices = (time_corrected - 1) * ny + phase_corrected + + unique_indices, counts = np.unique(trajectory_indices, return_counts=True) + repeated_values = unique_indices[counts != 1] + duplicate_indices = [] + + for value in repeated_values: + duplicates = np.where(trajectory_indices == value)[0] + duplicate_indices.extend(duplicates[1:]) + + empty_indices = np.setdiff1d(np.arange(1, ny * nt + 1), trajectory_indices) + + for i in range(len(duplicate_indices)): + new_index = KtBaseMaskFunc.find_nearest_empty_location( + trajectory_indices[duplicate_indices[i]], empty_indices, ny ) + trajectory_indices[duplicate_indices[i]] = new_index + empty_indices = np.setdiff1d(empty_indices, new_index) + + phase_corrected, time_corrected = KtBaseMaskFunc.linear_indices_to_2d_coordinates(trajectory_indices, ny) + phase_corrected = phase_corrected - np.ceil((ny + 1) / 2) + time_corrected = time_corrected - np.ceil((nt + 1) / 2) + + return phase_corrected, time_corrected + + @staticmethod + def crop_center(array, target_height, target_width): + """Crops the center of an array to the target height and width. + + Parameters + ---------- + array : ndarray + The input array. + target_height : int + The target height. + target_width : int + The target width. + + Returns + ------- + ndarray + The cropped array. + """ + target_shape = [target_height, target_width] + current_shape = list(array.shape) + + # Extend target_shape if it has fewer dimensions than current_shape + if len(target_shape) < len(current_shape): + target_shape.extend([1] * (len(current_shape) - len(target_shape))) + + # If the shapes are already the same, return the original array + if current_shape == target_shape: + return array + + # Calculate the slices for cropping the array + crop_slices = tuple( + slice((current_dim - target_dim) // 2, (current_dim - target_dim) // 2 + target_dim) + for current_dim, target_dim in zip(current_shape, target_shape) + ) - return torch.from_numpy(mask[np.newaxis, ..., np.newaxis]).bool() + # Crop the array + cropped_array = array[crop_slices] + return cropped_array + + +class KtRadialMaskFunc(KtBaseMaskFunc): + """Kt radial mask function. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction of low-frequency columns (float < 1.0) or number of low-frequence columns (integer) to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + crop_corner : bool, optional + If True, the mask is cropped to the corners. Default: True. + """ + + def __init__( + self, + accelerations: Union[list[Number], tuple[Number, ...]], + center_fractions: Union[list[float], tuple[float, ...]], + uniform_range: bool = False, + crop_corner: bool = True, + ) -> None: + """Inits :class:`KtRadialMaskFunc`. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction of low-frequency columns (float < 1.0) or number of low-frequence columns (integer) to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + crop_corner : bool, optional + If True, the mask is cropped to the corners. Default: True. + """ + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + ) + self.crop_corner = crop_corner + + def mask_func( + self, + shape: Union[list[int], tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: + """Creates a kt radial mask. + + Parameters + ---------- + shape : list or tuple of ints + The shape of the mask to be created. The shape should at least 3 dimensions. + Samples are drawn along the second last dimension. + return_acs : bool + Return the autocalibration signal region as a mask. + seed : int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. + + Returns + ------- + torch.Tensor + The sampling mask. + """ + if len(shape) not in [4, 5]: + raise ValueError("Shape should have 4 or 5 dimensions.") + + (nt, num_rows, num_cols) = shape[-4:-1] + + with temp_seed(self.rng, seed): + + center_fraction, acceleration = self.choose_acceleration() + num_low_freqs = int(round(num_cols * center_fraction)) + + offset_angle = self.rng.uniform(0, 360) + + acs_mask = self.zero_pad_to_center(np.ones((nt, num_low_freqs, num_low_freqs)), [nt, num_rows, num_cols]) + + if return_acs: + return torch.from_numpy(acs_mask.astype(bool)[np.newaxis, ..., np.newaxis]) + + adjusted_acceleration = (acceleration * (num_low_freqs**2 - num_rows * num_cols)) / ( + num_low_freqs**2 * acceleration - num_rows * num_cols + ) + + rate = 1 / adjusted_acceleration + beams = int(rate * np.mean([num_rows, num_cols])) # beams is the number of angles + + if self.crop_corner: + temp_size = max(num_rows, num_cols) + else: + temp_size = int(np.sqrt(num_rows**2 + num_cols**2)) + + aux = np.zeros((temp_size, temp_size)) + aux[int(temp_size / 2), :] = 1 + + mask = np.zeros((nt, num_rows, num_cols)) + for i in range(nt): + angles = np.linspace(0 + offset_angle * i, 180 + offset_angle * i + 1, beams) + mask_t = np.zeros((num_rows, num_cols)) + for ang in angles: + temp = self.crop_center(rotate(aux, ang, reshape=False, order=0), num_rows, num_cols) + mask_t += temp + mask[i] = mask_t + + mask = mask + acs_mask + mask = mask > 0 + + return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]) + + +class KtUniformMaskFunc(KtBaseMaskFunc): + """Kt uniform mask function. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction of low-frequency columns (float < 1.0) or number of low-frequence columns (integer) to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + """ + + def __init__( + self, + accelerations: Union[list[Number], tuple[Number, ...]], + center_fractions: Union[list[float], tuple[float, ...]], + uniform_range: bool = False, + ) -> None: + """Inits :class:`KtUniformMaskFunc`. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction of low-frequency columns (float < 1.0) or number of low-frequence columns (integer) to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + """ + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + ) + + def mask_func( + self, + shape: Union[list[int], tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: + """Creates a kt uniform mask. + + Parameters + ---------- + shape : list or tuple of ints + The shape of the mask to be created. The shape should at least 3 dimensions. + Samples are drawn along the second last dimension. + return_acs : bool + Return the autocalibration signal region as a mask. + seed : int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. + + Returns + ------- + torch.Tensor + The sampling mask. + """ + if len(shape) not in [4, 5]: + raise ValueError("Shape should have 4 or 5 dimensions.") + + (nt, num_rows, num_cols) = shape[-4:-1] + + with temp_seed(self.rng, seed): + + center_fraction, acceleration = self.choose_acceleration() + num_low_freqs = int(round(num_cols * center_fraction)) + + acs_mask = self.zero_pad_to_center(np.ones((nt, num_rows, num_low_freqs)), [nt, num_rows, num_cols]) + + if return_acs: + return torch.from_numpy(acs_mask.astype(bool)[np.newaxis, ..., np.newaxis]) + + adjusted_acceleration = int( + (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols) + ) + + ptmp = np.zeros(num_cols) + ttmp = np.zeros(nt) + + ptmp[ + np.round( + np.arange(self.rng.randint(0, adjusted_acceleration), num_cols, adjusted_acceleration) + ).astype(int) + ] = 1 + ttmp[ + np.round(np.arange(self.rng.randint(0, adjusted_acceleration), nt, adjusted_acceleration)).astype(int) + ] = 1 + + top_mat = toeplitz(ptmp, ttmp) + ind = np.where(top_mat.ravel())[0] + + ph = (ind % num_cols) - (num_cols // 2) + ti = (ind // num_cols) - (nt // 2) + + ph, ti = self.resolve_duplicates_on_kt_grid(ph, ti, num_cols, nt) + samp = np.zeros((num_cols, nt), dtype=int) + indices = num_cols * (ti + (nt // 2)) + (ph + (num_cols // 2)) + indices[indices <= 0] = 1 # Ensure indices are within bounds + samp.ravel()[indices.astype(int)] = 1 + + mask = np.tile(samp, (num_rows, 1, 1)).transpose(2, 0, 1) + mask = mask + acs_mask + mask = mask > 0 + + return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]) + + +class KtGaussian1DMaskFunc(KtBaseMaskFunc): + """Kt Gaussian 1D mask function. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction of low-frequency columns (float < 1.0) or number of low-frequence columns (integer) to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + alpha : float, optional + 0 < alpha < 1 controls sampling density; 0: uniform density, 1: maximally non-uniform density. + Default: 0.28. + std_scale : float, optional + The standard deviation scaling of the Gaussian envelope for sampling density. Default: 5.0. + """ + + def __init__( + self, + accelerations: Union[list[Number], tuple[Number, ...]], + center_fractions: Union[list[float], tuple[float, ...]], + uniform_range: bool = False, + alpha: float = 0.28, + std_scale: float = 5.0, + ) -> None: + """Inits :class:`KtGaussian1DMaskFunc`. + + Parameters + ---------- + accelerations : Union[list[Number], tuple[Number, ...]] + Amount of under-sampling. + center_fractions : Union[list[float], tuple[float, ...]] + Fraction of low-frequency columns (float < 1.0) or number of low-frequence columns (integer) to be retained. + uniform_range : bool, optional + If True then an acceleration will be uniformly sampled between the two values, by default False. + alpha : float, optional + 0 < alpha < 1 controls sampling density; 0: uniform density, 1: maximally non-uniform density. + Default: 0.28. + std_scale : float, optional + The standard deviation scaling of the Gaussian envelope for sampling density. Default: 5.0. + """ + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + ) + self.alpha = alpha + self.std_scale = std_scale + + def mask_func( + self, + shape: Union[list[int], tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: + """Creates a kt Gaussian 1D mask. + + Parameters + ---------- + shape : list or tuple of ints + The shape of the mask to be created. The shape should at least 3 dimensions. + Samples are drawn along the second last dimension. + return_acs : bool + Return the autocalibration signal region as a mask. + seed : int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. + + Returns + ------- + torch.Tensor + The sampling mask. + """ + if len(shape) not in [4, 5]: + raise ValueError("Shape should have 4 or 5 dimensions.") + + (nt, num_rows, num_cols) = shape[-4:-1] + + with temp_seed(self.rng, seed): + + center_fraction, acceleration = self.choose_acceleration() + num_low_freqs = int(round(num_cols * center_fraction)) + + acs_mask = self.zero_pad_to_center(np.ones((nt, num_rows, num_low_freqs)), [nt, num_rows, num_cols]) + + if return_acs: + return torch.from_numpy(acs_mask.astype(bool)[np.newaxis, ..., np.newaxis]) + + adjusted_acceleration = int( + (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols) + ) + + p1 = np.arange(-num_cols // 2, num_cols // 2) + t1 = [] + + tr = round(num_cols / adjusted_acceleration) # Number of readout lines per frame (temporal resolution) + ti = np.zeros(tr * nt, dtype=int) + ph = np.zeros(tr * nt, dtype=int) + + sigma = num_cols / self.std_scale # Std of the Gaussian envelope for sampling density + + prob = 0.1 + self.alpha / (1 - self.alpha + 1e-10) * np.exp(-((p1) ** 2) / (sigma**2)) + + ind = 0 + for i in range(-nt // 2, nt // 2): + a = np.where(np.array(t1) == i)[0] + n_tmp = tr - len(a) + prob_tmp = prob.copy() + prob_tmp[a] = 0 + p_tmp = self.rng.choice(np.arange(-num_cols // 2, num_cols // 2), n_tmp, p=prob_tmp / prob_tmp.sum()) + ti[ind : ind + n_tmp] = i + ph[ind : ind + n_tmp] = p_tmp + ind += n_tmp + + ph, ti = self.resolve_duplicates_on_kt_grid(ph, ti, num_cols, nt) + samp = np.zeros((num_cols, nt), dtype=int) + inds = num_cols * (ti + nt // 2) + (ph + num_cols // 2) + inds = inds.astype(int) + samp.ravel()[inds] = 1 + + mask = np.tile(samp, (num_rows, 1, 1)).transpose(2, 0, 1) + mask = mask + acs_mask + mask = mask > 0 + + return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]) def integerize_seed(seed: Union[None, tuple[int, ...], list[int]]) -> int: """Returns an integer seed. - If input is integer, will return it. If it's None, it will return a random integer in [0, 1e6). If it's a tuple - or list of integers, will return the integer part of the average. + If input is integer, will return the input. If input is None, will return a random integer seed. + If input is a tuple or list, will return a random integer seed based on the input. Can be useful for functions that take as input only integer seeds (e.g. cython functions). @@ -1754,13 +2813,15 @@ def integerize_seed(seed: Union[None, tuple[int, ...], list[int]]) -> int: out_seed: int Integer seed. """ - if seed is None: - return np.random.randint(0, 1e6) if isinstance(seed, int): return seed - if isinstance(seed, (tuple, list)): - return int(np.mean(seed)) - raise ValueError(f"Invalid seed type. Can be None, integer, or tuple or list of integers. Received {seed}.") + else: + rng = np.random.RandomState() + if seed is None: + return rng.randint(0, 1e6) + elif isinstance(seed, (tuple, list)): + with temp_seed(rng, seed): + return rng.randint(0, 1e6) def centered_disk_mask(shape: Union[list[int], tuple[int, ...]], center_scale: float) -> np.ndarray: @@ -1787,21 +2848,12 @@ def centered_disk_mask(shape: Union[list[int], tuple[int, ...]], center_scale: f return mask.astype(int) -class DictionaryMaskFunc(BaseMaskFunc): - def __init__(self, data_dictionary, **kwargs): # noqa - super().__init__(accelerations=None) - - self.data_dictionary = data_dictionary - - def mask_func(self, data, return_acs=False): - return self.data_dictionary[data] - - def build_masking_function( name: str, accelerations: Union[list[Number], tuple[Number, ...]], center_fractions: Optional[Union[list[Number], tuple[Number, ...]]] = None, uniform_range: bool = False, + mode: MaskFuncMode = MaskFuncMode.STATIC, **kwargs: dict[str, Any], ) -> BaseMaskFunc: """Builds a mask function. @@ -1818,6 +2870,13 @@ def build_masking_function( is True, then two values should be given, by default None. uniform_range : bool, optional If True then an acceleration will be uniformly sampled between the two values, by default False. + mode : MaskFuncMode, optional + Mode of the mask function. Can be MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, or MaskFuncMode.MULTISLICE. + If MaskFuncMode.STATIC, then a single mask is created independent of the requested shape, and will be + broadcasted to the shape by expanding other dimensions with 1, if applicable. If MaskFuncMode.DYNAMIC, + this expects the shape to have more then 3 dimensions, and the mask will be created for each time frame + along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each slice + along the fourth last dimension. Default: MaskFuncMode.STATIC. **kwargs : dict[str, Any], optional Additional keyword arguments to be passed to the mask function. These will be passed as keyword arguments to the mask function constructor. If the mask function constructor does not accept these arguments, they will @@ -1842,6 +2901,7 @@ def build_masking_function( { "center_fractions": center_fractions, "uniform_range": uniform_range, + "mode": mode, } ) # Now, iterate over the kwargs diff --git a/direct/common/subsample_config.py b/direct/common/subsample_config.py index ed49ad850..89cccbb35 100644 --- a/direct/common/subsample_config.py +++ b/direct/common/subsample_config.py @@ -1,20 +1,23 @@ -# coding=utf-8 # Copyright (c) DIRECT Contributors + +from __future__ import annotations + from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional from omegaconf import MISSING from direct.config.defaults import BaseConfig +from direct.types import MaskFuncMode @dataclass class MaskingConfig(BaseConfig): name: str = MISSING - accelerations: Tuple[int, ...] = (5,) # Ideally Union[float, int]. - center_fractions: Optional[Tuple[float, ...]] = (0.1,) # Ideally Optional[Tuple[float, ...]] + accelerations: tuple[float, ...] = (5.0,) + center_fractions: Optional[tuple[float, ...]] = (0.1,) uniform_range: bool = False - image_center_crop: bool = False + mode: MaskFuncMode = MaskFuncMode.STATIC - val_accelerations: Tuple[int, ...] = (5, 10) - val_center_fractions: Optional[Tuple[float, ...]] = (0.1, 0.05) + val_accelerations: tuple[float, ...] = (5.0, 10.0) + val_center_fractions: Optional[tuple[float, ...]] = (0.1, 0.05) diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index 606c56e5f..0408a3654 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -317,7 +317,7 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: Sample with `sampling_mask` key. """ if not self.shape: - shape = sample["kspace"].shape[-3:] + shape = sample["kspace"].shape[1:] elif any(_ is None for _ in self.shape): # Allow None as values. kspace_shape = list(sample["kspace"].shape[1:-1]) shape = tuple(_ if _ else kspace_shape[idx] for idx, _ in enumerate(self.shape)) + (2,) @@ -328,9 +328,6 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: sampling_mask = self.mask_func(shape=shape, seed=seed, return_acs=False) - if sample["kspace"].ndim == 5: - sampling_mask = sampling_mask.unsqueeze(0) - if "padding" in sample: sampling_mask = T.apply_padding(sampling_mask, sample["padding"]) diff --git a/direct/types.py b/direct/types.py index e3bda2080..c0a8eddf7 100644 --- a/direct/types.py +++ b/direct/types.py @@ -8,6 +8,7 @@ from enum import Enum from typing import NewType, Union +import numpy as np import torch from omegaconf.omegaconf import DictConfig from torch import nn as nn @@ -19,6 +20,7 @@ FileOrUrl = NewType("FileOrUrl", PathOrString) HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler] TensorOrNone = Union[None, torch.Tensor] +TensorOrNdarray = Union[torch.Tensor, np.ndarray] class DirectEnum(str, Enum): @@ -57,6 +59,16 @@ class TransformKey(DirectEnum): SAMPLING_MASK = "sampling_mask" ACS_MASK = "acs_mask" SCALING_FACTOR = "scaling_factor" + REFERENCE_IMAGE = "reference_image" + MOVING_IMAGE = "moving_image" + WARPED_IMAGE = "warped_image" + DISPLACEMENT_FIELD = "displacement_field" + + +class MaskFuncMode(DirectEnum): + STATIC = "static" + DYNAMIC = "dynamic" + MULTISLICE = "multislice" class IntegerListOrTupleStringMeta(type): diff --git a/direct/utils/__init__.py b/direct/utils/__init__.py index 969140469..1f373deb2 100644 --- a/direct/utils/__init__.py +++ b/direct/utils/__init__.py @@ -535,3 +535,59 @@ def dict_flatten(in_dict: DictOrDictConfig, dict_out: Optional[DictOrDictConfig] continue dict_out[k] = v return dict_out + + +def reshape_array_to_shape(array: np.ndarray, requested_shape: Tuple[int, ...]) -> np.ndarray: + """Reshapes the given array to match the requested shape by adding dimensions of size 1 where necessary. + + Parameters + ---------- + array : np.ndarray + The input array to be reshaped. + requested_shape tuple of ints + The desired shape of the output array. + + Returns + ------- + np.ndarray + The reshaped array with the requested shape. + + Example + ------- + >>> array1 = np.random.rand(4, 5) + >>> requested_shape1 = (4, 5, 1) + >>> result1 = reshape_array_to_shape(array1, requested_shape1) + >>> print(result1.shape) # Output: (4, 5, 1) + + >>> array2 = np.random.rand(4, 5) + >>> requested_shape2 = (1, 4, 5, 1) + >>> result2 = reshape_array_to_shape(array2, requested_shape2) + >>> print(result2.shape) # Output: (1, 4, 5, 1) + + >>> array3 = np.random.rand(2, 4, 5) + >>> requested_shape3 = (2, 4, 5, 1) + >>> result3 = reshape_array_to_shape(array3, requested_shape3) + >>> print(result3.shape) # Output: (2, 4, 5, 1) + """ + + # Get the current shape of the array + current_shape = array.shape + + # Check if the current shape already matches the requested shape + if current_shape == requested_shape: + return array + + # Initialize a new shape list with ones + new_shape = [1] * len(requested_shape) + + # Fill in the new shape list with dimensions from the current shape where appropriate + j = 0 # Index for current shape + for i in range(len(requested_shape)): + if j < len(current_shape) and requested_shape[i] == current_shape[j]: + new_shape[i] = current_shape[j] + j += 1 + + # Reshape the array to the new shape + reshaped_array = np.reshape(array, new_shape) + + return reshaped_array diff --git a/tests/test_train.py b/tests/test_train.py index f64cbb40f..997882076 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -25,12 +25,13 @@ ) from direct.launch import launch from direct.train import setup_train +from direct.types import MaskFuncMode def create_test_transform_cfg(transforms_type): transforms_config = TransformsConfig( normalization=NormalizationTransformConfig(scaling_key="masked_kspace"), - masking=MaskingConfig(name="FastMRIRandom"), + masking=MaskingConfig(name="FastMRIRandom", mode=MaskFuncMode.STATIC), cropping=CropTransformConfig(crop="(32, 32)"), sensitivity_map_estimation=SensitivityMapEstimationTransformConfig(estimate_sensitivity_maps=True), transforms_type=transforms_type, diff --git a/tests/tests_common/test_subsample.py b/tests/tests_common/test_subsample.py index 0c0ef860c..3c76a106a 100644 --- a/tests/tests_common/test_subsample.py +++ b/tests/tests_common/test_subsample.py @@ -32,9 +32,13 @@ ([0.2, 0.4], [4, 8], 2, 368), ], ) -def test_mask_reuse(mask_func, center_fracs, accelerations, batch_size, dim): - mask_func = mask_func(center_fractions=center_fracs, accelerations=accelerations) - shape = (batch_size, dim, dim, 2) +@pytest.mark.parametrize( + "mode", + [MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE], +) +def test_mask_reuse(mask_func, center_fracs, accelerations, batch_size, dim, mode): + mask_func = mask_func(center_fractions=center_fracs, accelerations=accelerations, mode=mode) + shape = (batch_size, dim, dim, 2) if mode == MaskFuncMode.STATIC else (batch_size, dim // 100, dim, dim, 2) mask1 = mask_func(shape, seed=123) mask2 = mask_func(shape, seed=123) mask3 = mask_func(shape, seed=123) @@ -50,15 +54,19 @@ def test_mask_reuse(mask_func, center_fracs, accelerations, batch_size, dim): ], ) @pytest.mark.parametrize( - "accelerations, batch_size, dim", + "center_fracs, accelerations, batch_size, dim", [ - ([4], 4, 320), - ([4, 8], 2, 368), + ([0.2], [4], 4, 320), + ([0.2, 0.4], [4, 8], 2, 368), ], ) -def test_mask_reuse_circus(mask_func, accelerations, batch_size, dim): - mask_func = mask_func(accelerations=accelerations) - shape = (batch_size, dim, dim, 2) +@pytest.mark.parametrize( + "mode", + [MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE], +) +def test_mask_reuse_circus(mask_func, center_fracs, accelerations, batch_size, dim, mode): + mask_func = mask_func(accelerations=accelerations, center_fractions=center_fracs, mode=mode) + shape = (batch_size, dim, dim, 2) if mode == MaskFuncMode.STATIC else (batch_size, dim // 100, dim, dim, 2) mask1 = mask_func(shape, seed=123) mask2 = mask_func(shape, seed=123) mask3 = mask_func(shape, seed=123) @@ -81,9 +89,34 @@ def test_mask_reuse_circus(mask_func, accelerations, batch_size, dim): ([30, 20], [4, 8], 2, 368), ], ) -def test_mask_reuse_cartesian(mask_func, center_fracs, accelerations, batch_size, dim): +@pytest.mark.parametrize( + "mode", + [MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE], +) +def test_mask_reuse_cartesian(mask_func, center_fracs, accelerations, batch_size, dim, mode): + mask_func = mask_func(center_fractions=center_fracs, accelerations=accelerations, mode=mode) + shape = (batch_size, dim, dim, 2) if mode == MaskFuncMode.STATIC else (batch_size, dim // 100, dim, dim, 2) + mask1 = mask_func(shape, seed=123) + mask2 = mask_func(shape, seed=123) + mask3 = mask_func(shape, seed=123) + assert torch.all(mask1 == mask2) + assert torch.all(mask2 == mask3) + + +@pytest.mark.parametrize( + "mask_func", + [KtGaussian1DMaskFunc, KtRadialMaskFunc, KtUniformMaskFunc], +) +@pytest.mark.parametrize( + "center_fracs, accelerations, batch_size, shape", + [ + ([0.2], [4], 4, [10, 200, 300]), + ([0.2, 0.4], [4, 8], 2, [4, 220, 200]), + ], +) +def test_mask_reuse_kt(mask_func, center_fracs, accelerations, batch_size, shape): mask_func = mask_func(center_fractions=center_fracs, accelerations=accelerations) - shape = (batch_size, dim, dim, 2) + shape = (batch_size, *shape, 2) mask1 = mask_func(shape, seed=123) mask2 = mask_func(shape, seed=123) mask3 = mask_func(shape, seed=123) @@ -102,13 +135,18 @@ def test_mask_reuse_cartesian(mask_func, center_fracs, accelerations, batch_size ([0.2, 0.4], [4, 8], 2, 368), ], ) -def test_cartesian_mask_low_freqs(mask_func, center_fracs, accelerations, batch_size, dim): - mask_func = mask_func(center_fractions=center_fracs, accelerations=accelerations) - shape = (batch_size, dim, dim, 2) +@pytest.mark.parametrize( + "mode", + [MaskFuncMode.STATIC, MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE], +) +def test_cartesian_mask_low_freqs(mask_func, center_fracs, accelerations, batch_size, dim, mode): + mask_func = mask_func(center_fractions=center_fracs, accelerations=accelerations, mode=mode) + shape = (batch_size, dim, dim, 2) if mode == MaskFuncMode.STATIC else (batch_size, dim // 100, dim, dim, 2) mask = mask_func(shape, seed=123) + mask_shape = [1] * (len(shape) + 1) - mask_shape[-2] = dim - mask_shape[-3] = dim + mask_shape[-3:-1] = shape[-3:-1] + mask_shape[-4] = mask_shape[-4] if mode == MaskFuncMode.STATIC else shape[-4] assert list(mask.shape) == mask_shape @@ -126,21 +164,26 @@ def test_cartesian_mask_low_freqs(mask_func, center_fracs, accelerations, batch_ [FastMRIRandomMaskFunc, FastMRIEquispacedMaskFunc, FastMRIMagicMaskFunc, Gaussian1DMaskFunc], ) @pytest.mark.parametrize( - "shape, center_fractions, accelerations", + "shape, center_fractions, accelerations, mode", [ - ([4, 32, 32, 2], [0.08], [4]), - ([2, 64, 64, 2], [0.04, 0.08], [8, 4]), + ([4, 32, 32, 2], [0.08], [4], MaskFuncMode.STATIC), + ([2, 64, 64, 2], [0.04, 0.08], [8, 4], MaskFuncMode.STATIC), + ([4, 5, 32, 32, 2], [0.08], [4], MaskFuncMode.STATIC), + ([4, 5, 32, 32, 2], [0.08], [4], MaskFuncMode.DYNAMIC), + ([4, 5, 32, 32, 2], [0.04], [8], MaskFuncMode.MULTISLICE), ], ) -def test_apply_mask_cartesian(mask_func, shape, center_fractions, accelerations): - mask_func = mask_func(center_fractions=center_fractions, accelerations=accelerations) +def test_apply_mask_cartesian(mask_func, shape, center_fractions, accelerations, mode): + mask_func = mask_func(center_fractions=center_fractions, accelerations=accelerations, mode=mode) mask = mask_func(shape[1:], seed=123) acs_mask = mask_func(shape[1:], seed=123, return_acs=True) - expected_mask_shape = (1, shape[1], shape[2], 1) + expected_mask_shape = [1] * len(shape) + expected_mask_shape[-3:-1] = shape[-3:-1] + expected_mask_shape[-4] = expected_mask_shape[-4] if mode == MaskFuncMode.STATIC else shape[-4] assert mask.max() == 1 assert mask.min() == 0 - assert mask.shape == expected_mask_shape + assert mask.shape == tuple(expected_mask_shape) assert np.allclose(mask & acs_mask, acs_mask) @@ -168,14 +211,17 @@ def test_same_across_volumes_mask_cartesian_fraction_center(mask_func, shape, ce [CartesianEquispacedMaskFunc, CartesianMagicMaskFunc, CartesianRandomMaskFunc], ) @pytest.mark.parametrize( - "shape, center_fractions, accelerations", + "shape, center_fractions, accelerations, mode", [ - ([4, 32, 32, 2], [6], [4]), - ([2, 64, 64, 2], [4, 6], [8, 4]), + ([4, 32, 32, 2], [6], [4], MaskFuncMode.STATIC), + ([2, 64, 64, 2], [4, 6], [8, 4], MaskFuncMode.STATIC), + ([4, 5, 32, 32, 2], [6], [4], MaskFuncMode.STATIC), + ([4, 5, 32, 32, 2], [6], [4], MaskFuncMode.DYNAMIC), + ([4, 5, 32, 32, 2], [6], [4], MaskFuncMode.MULTISLICE), ], ) -def test_same_across_volumes_mask_cartesian(mask_func, shape, center_fractions, accelerations): - mask_func = mask_func(center_fractions=center_fractions, accelerations=accelerations) +def test_same_across_volumes_mask_cartesian(mask_func, shape, center_fractions, accelerations, mode): + mask_func = mask_func(center_fractions=center_fractions, accelerations=accelerations, mode=mode) num_slices = shape[0] masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] @@ -232,86 +278,116 @@ def test_same_across_volumes_mask_calgary_campinas(shape, accelerations): @pytest.mark.parametrize( - "shape, accelerations", + "shape, accelerations, center_fractions, mode", [ - ([4, 32, 32, 2], [4]), - ([2, 64, 64, 2], [8, 4]), + ([4, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([4, 3, 32, 32, 2], [4], None, MaskFuncMode.DYNAMIC), + ([2, 3, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.DYNAMIC), + ([4, 1, 32, 32, 2], [4], None, MaskFuncMode.STATIC), + ([2, 3, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([4, 32, 32, 2], [4], None, MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], None, MaskFuncMode.STATIC), ], ) -def test_apply_mask_radial(shape, accelerations): - mask_func = RadialMaskFunc( - accelerations=accelerations, - ) +def test_apply_mask_radial(shape, accelerations, center_fractions, mode): + mask_func = RadialMaskFunc(accelerations=accelerations, center_fractions=center_fractions, mode=mode) mask = mask_func(shape[1:], seed=123) acs_mask = mask_func(shape[1:], seed=123, return_acs=True) - expected_mask_shape = (1, shape[1], shape[2], 1) + expected_mask_shape = [1] * len(shape) + expected_mask_shape[-3:-1] = shape[-3:-1] + expected_mask_shape[-4] = expected_mask_shape[-4] if mode == MaskFuncMode.STATIC else shape[-4] assert mask.max() == 1 assert mask.min() == 0 - assert mask.shape == expected_mask_shape + assert mask.shape == tuple(expected_mask_shape) assert np.allclose(mask & acs_mask, acs_mask) @pytest.mark.parametrize( - "shape, accelerations", + "shape, accelerations, center_fractions, mode", [ - ([4, 32, 32, 2], [4]), - ([2, 64, 64, 2], [8, 4]), + ([4, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([4, 32, 32, 2], [4], None, MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], None, MaskFuncMode.STATIC), + ([1, 3, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([1, 2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([1, 3, 32, 32, 2], [4], None, MaskFuncMode.STATIC), + ([1, 3, 64, 64, 2], [8, 4], None, MaskFuncMode.STATIC), + ([1, 3, 32, 32, 2], [4], [0.08], MaskFuncMode.DYNAMIC), + ([1, 2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.DYNAMIC), + ([1, 3, 32, 32, 2], [4], None, MaskFuncMode.DYNAMIC), + ([1, 3, 64, 64, 2], [8, 4], None, MaskFuncMode.DYNAMIC), ], ) -def test_same_across_volumes_mask_radial(shape, accelerations): - mask_func = RadialMaskFunc( - accelerations=accelerations, - ) - num_slices = shape[0] - masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] +def test_same_across_volumes_mask_radial(shape, accelerations, center_fractions, mode): + mask_func = RadialMaskFunc(accelerations=accelerations, center_fractions=center_fractions, mode=mode) + batch_sz = shape[0] + masks = [mask_func(shape[1:], seed=123) for _ in range(batch_sz)] - assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(num_slices - 1)) + assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(batch_sz - 1)) @pytest.mark.parametrize( - "shape, accelerations", + "shape, accelerations, center_fractions, mode", [ - ([4, 32, 32, 2], [4]), - ([2, 64, 64, 2], [8, 4]), + ([4, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([4, 3, 32, 32, 2], [4], None, MaskFuncMode.DYNAMIC), + ([2, 3, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.DYNAMIC), + ([4, 1, 32, 32, 2], [4], None, MaskFuncMode.STATIC), + ([2, 3, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([4, 32, 32, 2], [4], None, MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], None, MaskFuncMode.STATIC), ], ) -def test_apply_mask_spiral(shape, accelerations): - mask_func = SpiralMaskFunc( - accelerations=accelerations, - ) +def test_apply_mask_spiral(shape, accelerations, center_fractions, mode): + mask_func = SpiralMaskFunc(accelerations=accelerations, center_fractions=center_fractions, mode=mode) mask = mask_func(shape[1:], seed=123) acs_mask = mask_func(shape[1:], seed=123, return_acs=True) - expected_mask_shape = (1, shape[1], shape[2], 1) + expected_mask_shape = [1] * len(shape) + expected_mask_shape[-3:-1] = shape[-3:-1] + expected_mask_shape[-4] = expected_mask_shape[-4] if mode == MaskFuncMode.STATIC else shape[-4] assert mask.max() == 1 assert mask.min() == 0 - assert mask.shape == expected_mask_shape + assert mask.shape == tuple(expected_mask_shape) assert np.allclose(mask & acs_mask, acs_mask) @pytest.mark.parametrize( - "shape, accelerations", + "shape, accelerations, center_fractions, mode", [ - ([4, 32, 32, 2], [4]), - ([2, 64, 64, 2], [8, 4]), + ([4, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([4, 32, 32, 2], [4], None, MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], None, MaskFuncMode.STATIC), + ([1, 3, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([1, 2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([1, 3, 32, 32, 2], [4], None, MaskFuncMode.STATIC), + ([1, 3, 64, 64, 2], [8, 4], None, MaskFuncMode.STATIC), + ([1, 3, 32, 32, 2], [4], [0.08], MaskFuncMode.DYNAMIC), + ([1, 2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.DYNAMIC), + ([1, 3, 32, 32, 2], [4], None, MaskFuncMode.DYNAMIC), + ([1, 3, 64, 64, 2], [8, 4], None, MaskFuncMode.DYNAMIC), ], ) -def test_same_across_volumes_mask_spiral(shape, accelerations): - mask_func = SpiralMaskFunc( - accelerations=accelerations, - ) - num_slices = shape[0] - masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] +def test_same_across_volumes_mask_spiral(shape, accelerations, center_fractions, mode): + mask_func = SpiralMaskFunc(accelerations=accelerations, center_fractions=center_fractions, mode=mode) + batch_sz = shape[0] + masks = [mask_func(shape[1:], seed=123) for _ in range(batch_sz)] - assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(num_slices - 1)) + assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(batch_sz - 1)) @pytest.mark.parametrize( - "shape, accelerations, center_fractions", + "shape, accelerations, center_fractions, mode", [ - ([4, 32, 32, 2], [4], [0.08]), - ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), + ([4, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([2, 3, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.DYNAMIC), + ([2, 3, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), ], ) @pytest.mark.parametrize( @@ -326,59 +402,71 @@ def test_same_across_volumes_mask_spiral(shape, accelerations): tuple(np.random.randint(100000, 1000000, 30)), ], ) -def test_apply_mask_poisson(shape, accelerations, center_fractions, seed): +def test_apply_mask_poisson(shape, accelerations, center_fractions, seed, mode): mask_func = VariableDensityPoissonMaskFunc( accelerations=accelerations, center_fractions=center_fractions, + mode=mode, ) mask = mask_func(shape[1:], seed=seed) acs_mask = mask_func(shape[1:], seed=seed, return_acs=True) - expected_mask_shape = (1, shape[1], shape[2], 1) + expected_mask_shape = [1] * len(shape) + expected_mask_shape[-3:-1] = shape[-3:-1] + expected_mask_shape[-4] = expected_mask_shape[-4] if mode == MaskFuncMode.STATIC else shape[-4] assert mask.max() == 1 assert mask.min() == 0 - assert mask.shape == expected_mask_shape + assert mask.shape == tuple(expected_mask_shape) if seed is not None: assert np.allclose(mask & acs_mask, acs_mask) @pytest.mark.parametrize( - "shape, accelerations, center_fractions", + "shape, accelerations, center_fractions, mode", [ - ([4, 32, 32, 2], [4], [0.08]), - ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), + ([4, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([4, 2, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([4, 2, 32, 32, 2], [4], [0.08], MaskFuncMode.DYNAMIC), + ([2, 3, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.MULTISLICE), ], ) -def test_same_across_volumes_mask_poisson(shape, accelerations, center_fractions): +def test_same_across_volumes_mask_poisson(shape, accelerations, center_fractions, mode): mask_func = VariableDensityPoissonMaskFunc( accelerations=accelerations, center_fractions=center_fractions, + mode=mode, ) - num_slices = shape[0] - masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] + batch_sz = shape[0] + masks = [mask_func(shape[1:], seed=123) for _ in range(batch_sz)] - assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(num_slices - 1)) + assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(batch_sz - 1)) @pytest.mark.parametrize( - "shape, accelerations, center_fractions", + "shape, accelerations, center_fractions, mode", [ - ([4, 32, 32, 2], [4], [0.08]), - ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), + ([4, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([4, 2, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([4, 2, 32, 32, 2], [4], [0.08], MaskFuncMode.MULTISLICE), + ([2, 3, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.DYNAMIC), ], ) -def test_same_across_volumes_mask_gaussian_2d(shape, accelerations, center_fractions): - mask_func = Gaussian2DMaskFunc(accelerations=accelerations, center_fractions=center_fractions) - num_slices = shape[0] - masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] +def test_same_across_volumes_mask_gaussian_2d(shape, accelerations, center_fractions, mode): + mask_func = Gaussian2DMaskFunc(accelerations=accelerations, center_fractions=center_fractions, mode=mode) + batch_sz = shape[0] + masks = [mask_func(shape[1:], seed=123) for _ in range(batch_sz)] - assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(num_slices - 1)) + assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(batch_sz - 1)) @pytest.mark.parametrize( - "shape, accelerations, center_fractions", + "shape, accelerations, center_fractions, mode", [ - ([4, 32, 32, 2], [4], [0.08]), - ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), + ([4, 32, 32, 2], [4], [0.08], MaskFuncMode.STATIC), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), + ([2, 3, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.DYNAMIC), + ([2, 3, 64, 64, 2], [8, 4], [0.04, 0.08], MaskFuncMode.STATIC), ], ) @pytest.mark.parametrize( @@ -393,13 +481,46 @@ def test_same_across_volumes_mask_gaussian_2d(shape, accelerations, center_fract tuple(np.random.randint(100000, 1000000, 30)), ], ) -def test_apply_mask_gaussian_2d(shape, accelerations, center_fractions, seed): - mask_func = Gaussian2DMaskFunc(accelerations=accelerations, center_fractions=center_fractions) +def test_apply_mask_gaussian_2d(shape, accelerations, center_fractions, seed, mode): + mask_func = Gaussian2DMaskFunc(accelerations=accelerations, center_fractions=center_fractions, mode=mode) mask = mask_func(shape[1:], seed=seed) acs_mask = mask_func(shape[1:], seed=seed, return_acs=True) - expected_mask_shape = (1, shape[1], shape[2], 1) + expected_mask_shape = [1] * len(shape) + expected_mask_shape[-3:-1] = shape[-3:-1] + expected_mask_shape[-4] = expected_mask_shape[-4] if mode == MaskFuncMode.STATIC else shape[-4] assert mask.max() == 1 assert mask.min() == 0 - assert mask.shape == expected_mask_shape + assert mask.shape == tuple(expected_mask_shape) if seed is not None: assert np.allclose(mask & acs_mask, acs_mask) + + +@pytest.mark.parametrize( + "shape, accelerations, center_fractions", + [ + ([2, 10, 64, 64, 2], [8, 4], [0.04, 0.08]), + ], +) +@pytest.mark.parametrize( + "mask_func", + [ + KtGaussian1DMaskFunc, + KtRadialMaskFunc, + KtUniformMaskFunc, + ], +) +def test_apply_kt_mask(mask_func, shape, accelerations, center_fractions): + mask_func = mask_func(accelerations=accelerations, center_fractions=center_fractions) + mask = mask_func(shape[1:], seed=123) + acs_mask = mask_func(shape[1:], seed=123, return_acs=True) + + expected_mask_shape = [1] * len(shape) + expected_mask_shape[-3:-1] = shape[-3:-1] + expected_mask_shape[-4] = shape[-4] + + assert mask.max() == 1 + assert mask.min() == 0 + assert mask.shape == tuple(expected_mask_shape) + + assert all(not np.allclose(mask[:, _], mask[:, _ + 1]) for _ in range(shape[1] - 1)) + assert all(np.allclose(acs_mask[:, _], acs_mask[:, _ + 1]) for _ in range(shape[1] - 1)) diff --git a/tests/tests_data/test_mri_transforms.py b/tests/tests_data/test_mri_transforms.py index 145a30eb3..af38e1f6d 100644 --- a/tests/tests_data/test_mri_transforms.py +++ b/tests/tests_data/test_mri_transforms.py @@ -56,19 +56,23 @@ def create_sample(shape, **kwargs): def _mask_func(shape, seed=None, return_acs=False): - num_rows, num_cols = shape[:2] + num_rows, num_cols = shape[-3:-1] mask = torch.zeros(num_rows, num_cols).bool() mask[ num_rows // 2 - num_rows // 4 : num_rows // 2 + num_rows // 4, num_cols // 2 - num_cols // 4 : num_cols // 2 + num_cols // 4, ] = True + mask_shape = torch.ones(len(shape)).int().tolist() + mask_shape[-3] = num_rows + mask_shape[-2] = num_cols if return_acs: - return mask.unsqueeze(0).unsqueeze(-1) + return mask.reshape(mask_shape).unsqueeze(0) if seed: rng = np.random.RandomState() rng.seed(seed) - mask = mask | torch.from_numpy(np.random.rand(num_rows, num_cols)).round().bool() - return mask.unsqueeze(0).unsqueeze(-1) + mask = mask.reshape(mask_shape) | torch.from_numpy(np.random.rand(*mask_shape)).round().bool() + + return mask.unsqueeze(0) @pytest.mark.parametrize( @@ -168,9 +172,10 @@ def test_CreateSamplingMask(shape, return_acs, use_shape): sample = create_sample(shape) transform = CreateSamplingMask( - mask_func=_mask_func, shape=shape[-3:-1] if use_shape else None, return_acs=return_acs + mask_func=_mask_func, shape=shape[1:-1] if use_shape else None, return_acs=return_acs ) sample = transform(sample) + print(sample["kspace"].shape, sample["sampling_mask"].shape) assert "sampling_mask" in sample mask_shape = torch.ones(len(shape))