From 0abc8e24a8dd1a00cddb51abb820cb8c54c2e944 Mon Sep 17 00:00:00 2001 From: Matthew Muckley Date: Tue, 27 Oct 2020 09:54:24 -0400 Subject: [PATCH] Add typing (#91) * Add typing to data folder * Add typing to models * Better comments * Add typing to pl_modules * Typing for base module, config * Test fixes * Add mypy test * Add dev requirements * Adding a few more ignores for mypy * Single install cache --- .circleci/config.yml | 7 +- dev-requirements.txt | 4 + fastmri/coil_combine.py | 16 +- fastmri/data/mri_data.py | 138 +++++++++------- fastmri/data/subsample.py | 105 +++++++----- fastmri/data/transforms.py | 240 ++++++++++++++++------------ fastmri/data/volume_sampler.py | 36 +++-- fastmri/evaluate.py | 15 +- fastmri/losses.py | 16 +- fastmri/math.py | 128 +++++++++------ fastmri/models/unet.py | 69 ++++---- fastmri/models/varnet.py | 176 ++++++++++++-------- fastmri/pl_modules/data_module.py | 80 +++++----- fastmri/pl_modules/mri_module.py | 7 +- fastmri/pl_modules/varnet_module.py | 47 +++--- fastmri/utils.py | 15 +- mypy.ini | 9 ++ 17 files changed, 627 insertions(+), 481 deletions(-) create mode 100644 dev-requirements.txt create mode 100644 mypy.ini diff --git a/.circleci/config.yml b/.circleci/config.yml index 09717283..89e0ff47 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -9,12 +9,8 @@ jobs: executor: python/default steps: - checkout - - run: - name: Preinstallation Packages - command: | - pip install wheel - pip install pytest - python/install-packages: + pip-dependency-file: dev-requirements.txt pkg-manager: pip - run: name: Install fastMRI @@ -25,6 +21,7 @@ jobs: command: | pytest --version pytest tests + mypy fastmri name: Test workflows: diff --git a/dev-requirements.txt b/dev-requirements.txt new file mode 100644 index 00000000..13da1c12 --- /dev/null +++ b/dev-requirements.txt @@ -0,0 +1,4 @@ +-r requirements.txt +wheel +pytest +mypy \ No newline at end of file diff --git a/fastmri/coil_combine.py b/fastmri/coil_combine.py index e38dd309..b29b9d21 100644 --- a/fastmri/coil_combine.py +++ b/fastmri/coil_combine.py @@ -10,33 +10,33 @@ import fastmri -def rss(data, dim=0): +def rss(data: torch.Tensor, dim: int = 0) -> torch.Tensor: """ Compute the Root Sum of Squares (RSS). RSS is computed assuming that dim is the coil dimension. Args: - data (torch.Tensor): The input tensor - dim (int): The dimensions along which to apply the RSS transform + data: The input tensor + dim: The dimensions along which to apply the RSS transform Returns: - torch.Tensor: The RSS value. + The RSS value. """ return torch.sqrt((data ** 2).sum(dim)) -def rss_complex(data, dim=0): +def rss_complex(data: torch.Tensor, dim: int = 0) -> torch.Tensor: """ Compute the Root Sum of Squares (RSS) for complex inputs. RSS is computed assuming that dim is the coil dimension. Args: - data (torch.Tensor): The input tensor - dim (int): The dimensions along which to apply the RSS transform + data: The input tensor + dim: The dimensions along which to apply the RSS transform Returns: - torch.Tensor: The RSS value. + The RSS value. """ return torch.sqrt(fastmri.complex_abs_sq(data).sum(dim)) diff --git a/fastmri/data/mri_data.py b/fastmri/data/mri_data.py index ec3415ff..4dbafbdc 100644 --- a/fastmri/data/mri_data.py +++ b/fastmri/data/mri_data.py @@ -6,10 +6,12 @@ """ import logging -import pathlib +import os import pickle import random import xml.etree.ElementTree as etree +from pathlib import Path +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union from warnings import warn import h5py @@ -18,7 +20,11 @@ import yaml -def et_query(root, qlist, namespace="http://www.ismrm.org/ISMRMRD"): +def et_query( + root: etree.Element, + qlist: Sequence[str], + namespace: str = "http://www.ismrm.org/ISMRMRD", +) -> str: """ ElementTree query function. @@ -26,12 +32,13 @@ def et_query(root, qlist, namespace="http://www.ismrm.org/ISMRMRD"): for nexted queries. Args: - root (xml.etree.ElementTree.Element): Root of the xml. - qlist (Sequence): A list of strings for nested searches. - namespace (str): xml namespace. + root: Root of the xml to search through. + qlist: A list of strings for nested searches, e.g. ["Encoding", + "matrixSize"] + namespace: Optional; xml namespace to prepend query. Returns: - str: The retrieved data. + The retrieved data as a string. """ s = "." prefix = "ismrmrd_namespace" @@ -41,10 +48,16 @@ def et_query(root, qlist, namespace="http://www.ismrm.org/ISMRMRD"): for el in qlist: s = s + f"//{prefix}:{el}" - return root.find(s, ns).text + value = root.find(s, ns) + if value is None: + raise RuntimeError("Element not found") + return str(value.text) -def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")): + +def fetch_dir( + key: str, data_config_file: Union[str, Path, os.PathLike] = "fastmri_dirs.yaml" +) -> Path: """ Data directory fetcher. @@ -53,14 +66,15 @@ def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")): and this function will retrieve the requested subsplit of the data for use. Args: - key (str): key to retrieve path from data_config_file. - data_config_file (pathlib.Path, - default=pathlib.Path("fastmri_dirs.yaml")): Default path config - file. + key: key to retrieve path from data_config_file. Expected to be in + ("knee_path", "brain_path", "log_path"). + data_config_file: Optional; Default path config file to fetch path + from. Returns: - pathlib.Path: The path to the specified directory. + The path to the specified directory. """ + data_config_file = Path(data_config_file) if not data_config_file.is_file(): default_config = { "knee_path": "/path/to/knee", @@ -81,37 +95,44 @@ def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")): with open(data_config_file, "r") as f: data_dir = yaml.safe_load(f)[key] - data_dir = pathlib.Path(data_dir) - - return data_dir + return Path(data_dir) class CombinedSliceDataset(torch.utils.data.Dataset): """ A container for combining slice datasets. - - Args: - roots (list of pathlib.Path): Paths to the datasets. - transforms (list of callable): A callable object that pre-processes the - raw data into appropriate form. The transform function should take - 'kspace', 'target', 'attributes', 'filename', and 'slice' as - inputs. 'target' may be null for test data. - challenges (list of str): "singlecoil" or "multicoil" depending on which - challenge to use. - sample_rates (list of float, optional): A float between 0 and 1. This - controls what fraction of the volumes should be loaded. - num_cols (tuple(int), optional): if provided, only slices with the desired - number of columns will be considered. """ - def __init__(self, roots, transforms, challenges, sample_rates=None, num_cols=None): + def __init__( + self, + roots: Sequence[Path], + transforms: Sequence[Callable], + challenges: Sequence[str], + sample_rates: Optional[Sequence[float]] = None, + num_cols: Optional[Tuple[int]] = None, + ): + """ + Args: + roots: Paths to the datasets. + transforms: A callable object that preprocesses the raw data into + appropriate form. The transform function should take 'kspace', + 'target', 'attributes', 'filename', and 'slice' as inputs. + 'target' may be null for test data. + challenges: "singlecoil" or "multicoil" depending on which + challenge to use. + sample_rates: Optional; A float between 0 and 1. This controls + what fraction of the volumes should be loaded. + num_cols: Optional; If provided, only slices with the desired + number of columns will be considered. + """ assert len(roots) == len(transforms) == len(challenges) if sample_rates is not None: assert len(sample_rates) == len(roots) else: sample_rates = [1] * len(roots) - self.datasets = list() + self.datasets = [] + self.examples: List[Tuple[Path, int, Dict[str, object]]] = [] for i in range(len(roots)): self.datasets.append( SliceDataset( @@ -123,6 +144,8 @@ def __init__(self, roots, transforms, challenges, sample_rates=None, num_cols=No ) ) + self.examples = self.examples + self.datasets[-1].examples + def __len__(self): length = 0 for dataset in self.datasets: @@ -141,36 +164,37 @@ def __getitem__(self, i): class SliceDataset(torch.utils.data.Dataset): """ A PyTorch Dataset that provides access to MR image slices. - - Args: - root (pathlib.Path): Path to the dataset. - transform (callable): A callable object that pre-processes the raw data - into appropriate form. The transform function should take 'kspace', - 'target', 'attributes', 'filename', and 'slice' as inputs. 'target' - may be null for test data. - challenge (str): "singlecoil" or "multicoil" depending on which - challenge to use. - sample_rate (float, optional): A float between 0 and 1. This controls - what fraction of the volumes should be loaded. - dataset_cache_file (pathlib.Path). A file in which to cache dataset - information for faster load times. Default: dataset_cache.pkl. - num_cols (tuple(int), optional): if provided, only slices with the desired - number of columns will be considered. """ def __init__( self, - root, - transform, - challenge, - sample_rate=1, - dataset_cache_file=pathlib.Path("dataset_cache.pkl"), - num_cols=None, + root: Union[str, Path, os.PathLike], + transform: Callable, + challenge: str, + sample_rate: float = 1.0, + dataset_cache_file: Union[str, Path, os.PathLike] = "dataset_cache.pkl", + num_cols: Optional[Tuple[int]] = None, ): + """ + Args: + root: Path to the dataset. + transform: A callable object that pre-processes the raw data into + appropriate form. The transform function should take 'kspace', + 'target', 'attributes', 'filename', and 'slice' as inputs. + 'target' may be null for test data. + challenge: "singlecoil" or "multicoil" depending on which challenge + to use. + sample_rate: Optional; A float between 0 and 1. This controls what + fraction of the volumes should be loaded. Defaults to 1.0. + dataset_cache_file: Optional; A file in which to cache dataset + information for faster load times. + num_cols: Optional; If provided, only slices with the desired + number of columns will be considered. + """ if challenge not in ("singlecoil", "multicoil"): raise ValueError('challenge should be either "singlecoil" or "multicoil"') - self.dataset_cache_file = dataset_cache_file + self.dataset_cache_file = Path(dataset_cache_file) self.transform = transform self.recons_key = ( @@ -185,7 +209,7 @@ def __init__( dataset_cache = {} if dataset_cache.get(root) is None: - files = list(pathlib.Path(root).iterdir()) + files = list(Path(root).iterdir()) for fname in sorted(files): with h5py.File(fname, "r") as hf: et_root = etree.fromstring(hf["ismrmrd_header"][()]) @@ -238,13 +262,15 @@ def __init__( if num_cols: self.examples = [ - ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols + ex + for ex in self.examples + if ex[2]["encoding_size"][1] in num_cols # type: ignore ] def __len__(self): return len(self.examples) - def __getitem__(self, i): + def __getitem__(self, i: int): fname, dataslice, metadata = self.examples[i] with h5py.File(fname, "r") as hf: diff --git a/fastmri/data/subsample.py b/fastmri/data/subsample.py index d0620da3..4115aee4 100644 --- a/fastmri/data/subsample.py +++ b/fastmri/data/subsample.py @@ -6,31 +6,29 @@ """ import contextlib +from typing import Optional, Sequence, Tuple, Union import numpy as np import torch @contextlib.contextmanager -def temp_seed(rng, seed): - state = rng.get_state() - rng.seed(seed) - try: - yield - finally: - rng.set_state(state) - - -def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): - if mask_type_str == "random": - return RandomMaskFunc(center_fractions, accelerations) - elif mask_type_str == "equispaced": - return EquispacedMaskFunc(center_fractions, accelerations) +def temp_seed(rng: np.random, seed: Optional[Union[int, Tuple[int, ...]]]): + if seed is None: + try: + yield + finally: + pass else: - raise Exception(f"{mask_type_str} not supported") + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) -class MaskFunc(object): +class MaskFunc: """ An object for GRAPPA-style sampling masks. @@ -38,15 +36,15 @@ class MaskFunc(object): subsampling outer k-space regions based on the undersampling factor. """ - def __init__(self, center_fractions, accelerations): + def __init__(self, center_fractions: Sequence[float], accelerations: Sequence[int]): """ Args: - center_fractions (List[float]): Fraction of low-frequency columns to be - retained. If multiple values are provided, then one of these - numbers is chosen uniformly each time. - accelerations (List[int]): Amount of under-sampling. This should have - the same length as center_fractions. If multiple values are - provided, then one of these is chosen uniformly each time. + center_fractions: Fraction of low-frequency columns to be retained. + If multiple values are provided, then one of these numbers is + chosen uniformly each time. + accelerations: Amount of under-sampling. This should have the same + length as center_fractions. If multiple values are provided, + then one of these is chosen uniformly each time. """ if len(center_fractions) != len(accelerations): raise ValueError( @@ -57,6 +55,11 @@ def __init__(self, center_fractions, accelerations): self.accelerations = accelerations self.rng = np.random + def __call__( + self, shape: Sequence[int], seed: Optional[Union[int, Tuple[int, ...]]] = None + ) -> torch.Tensor: + raise NotImplementedError + def choose_acceleration(self): """Choose acceleration based on class parameters.""" choice = self.rng.randint(0, len(self.accelerations)) @@ -89,20 +92,22 @@ class RandomMaskFunc(MaskFunc): center fraction is selected. """ - def __call__(self, shape, seed=None): + def __call__( + self, shape: Sequence[int], seed: Optional[Union[int, Tuple[int, ...]]] = None + ) -> torch.Tensor: """ Create the mask. Args: - shape (iterable[int]): The shape of the mask to be created. The - shape should have at least 3 dimensions. Samples are drawn - along the second last dimension. - seed (int, optional): Seed for the random number generator. Setting - the seed ensures the same mask is generated each time for the - same shape. The random state is reset afterwards. - + shape: The shape of the mask to be created. The shape should have + at least 3 dimensions. Samples are drawn along the second last + dimension. + seed: Seed for the random number generator. Setting the seed + ensures the same mask is generated each time for the same + shape. The random state is reset afterwards. + Returns: - torch.Tensor: A mask of the specified shape. + A mask of the specified shape. """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -151,18 +156,20 @@ class EquispacedMaskFunc(MaskFunc): the function has been preserved to match the public multicoil data. """ - def __call__(self, shape, seed): + def __call__( + self, shape: Sequence[int], seed: Optional[Union[int, Tuple[int, ...]]] = None + ) -> torch.Tensor: """ Args: - shape (iterable[int]): The shape of the mask to be created. The - shape should have at least 3 dimensions. Samples are drawn - along the second last dimension. - seed (int, optional): Seed for the random number generator. Setting - the seed ensures the same mask is generated each time for the - same shape. The random state is reset afterwards. + shape: The shape of the mask to be created. The shape should have + at least 3 dimensions. Samples are drawn along the second last + dimension. + seed: Seed for the random number generator. Setting the seed + ensures the same mask is generated each time for the same + shape. The random state is reset afterwards. Returns: - torch.Tensor: A mask of the specified shape. + A mask of the specified shape. """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -193,3 +200,21 @@ def __call__(self, shape, seed): mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) return mask + + +def create_mask_for_mask_type( + mask_type_str: str, center_fractions: Sequence[float], accelerations: Sequence[int], +) -> MaskFunc: + """ + Creates a mask of the specified type. + + Args: + center_fractions: What fraction of the center of k-space to include. + accelerations: What accelerations to apply. + """ + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") diff --git a/fastmri/data/transforms.py b/fastmri/data/transforms.py index 7a4d6911..0252c015 100644 --- a/fastmri/data/transforms.py +++ b/fastmri/data/transforms.py @@ -5,12 +5,16 @@ LICENSE file in the root directory of this source tree. """ +from typing import Dict, Optional, Sequence, Tuple, Union + import fastmri import numpy as np import torch +from .subsample import MaskFunc + -def to_tensor(data): +def to_tensor(data: np.ndarray) -> torch.Tensor: """ Convert numpy array to PyTorch tensor. @@ -18,10 +22,10 @@ def to_tensor(data): dimension. Args: - data (np.array): Input numpy array. + data: Input numpy array. Returns: - torch.Tensor: PyTorch version of data. + PyTorch version of data. """ if np.iscomplexobj(data): data = np.stack((data.real, data.imag), axis=-1) @@ -29,41 +33,43 @@ def to_tensor(data): return torch.from_numpy(data) -def tensor_to_complex_np(data): +def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray: """ Converts a complex torch tensor to numpy array. Args: - data (torch.Tensor): Input data to be converted to numpy. + data: Input data to be converted to numpy. Returns: - np.array: Complex numpy version of data. + Complex numpy version of data. """ data = data.numpy() return data[..., 0] + 1j * data[..., 1] -def apply_mask(data, mask_func, seed=None, padding=None): +def apply_mask( + data: torch.Tensor, + mask_func: MaskFunc, + seed: Optional[Union[int, Tuple[int, ...]]] = None, + padding: Optional[Sequence[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: """ Subsample given k-space by multiplying with a mask. Args: - data (torch.Tensor): The input k-space data. This should have at - least 3 dimensions, where dimensions -3 and -2 are the spatial - dimensions, and the final dimension has size 2 (for complex - values). - mask_func (Callable): A function that takes a shape (tuple of ints) - and a random number seed and returns a mask. - seed (int or 1-d array_like, optional): Seed for the random number - generator. Defaults to None. - padding (tuple, optional): Padding value to apply for mask. Defaults to - None. + data: The input k-space data. This should have at least 3 dimensions, + where dimensions -3 and -2 are the spatial dimensions, and the + final dimension has size 2 (for complex values). + mask_func: A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed: Seed for the random number generator. + padding: Padding value to apply for mask. Returns: - (tuple): tuple containing: - masked data (torch.Tensor): Subsampled k-space data - mask (torch.Tensor): The generated mask + tuple containing: + masked data: Subsampled k-space data + mask: The generated mask """ shape = np.array(data.shape) shape[:-3] = 1 @@ -77,26 +83,36 @@ def apply_mask(data, mask_func, seed=None, padding=None): return masked_data, mask -def mask_center(x, mask_from, mask_to): +def mask_center(x: torch.Tensor, mask_from: int, mask_to: int) -> torch.Tensor: + """ + Initializes a mask with the center filled in. + + Args: + mask_from: Part of center to start filling. + mask_to: Part of center to end filling. + + Returns: + A mask with the center filled. + """ mask = torch.zeros_like(x) mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to] return mask -def center_crop(data, shape): +def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor: """ Apply a center crop to the input real image or batch of real images. Args: - data (torch.Tensor): The input tensor to be center cropped. It should + data: The input tensor to be center cropped. It should have at least 2 dimensions and the cropping is applied along the last two dimensions. - shape (int, int): The output shape. The shape should be smaller than - the corresponding dimensions of data. + shape: The output shape. The shape should be smaller + than the corresponding dimensions of data. Returns: - torch.Tensor: The center cropped image. + The center cropped image. """ assert 0 < shape[0] <= data.shape[-2] assert 0 < shape[1] <= data.shape[-1] @@ -109,20 +125,19 @@ def center_crop(data, shape): return data[..., w_from:w_to, h_from:h_to] -def complex_center_crop(data, shape): +def complex_center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor: """ Apply a center crop to the input image or batch of complex images. Args: - data (torch.Tensor): The complex input tensor to be center cropped. It - should have at least 3 dimensions and the cropping is applied along - dimensions -3 and -2 and the last dimensions should have a size of - 2. - shape (int): The output shape. The shape should be smaller than - the corresponding dimensions of data. + data: The complex input tensor to be center cropped. It should have at + least 3 dimensions and the cropping is applied along dimensions -3 + and -2 and the last dimensions should have a size of 2. + shape: The output shape. The shape should be smaller than the + corresponding dimensions of data. Returns: - torch.Tensor: The center cropped image + The center cropped image """ assert 0 < shape[0] <= data.shape[-3] assert 0 < shape[1] <= data.shape[-2] @@ -135,7 +150,9 @@ def complex_center_crop(data, shape): return data[..., w_from:w_to, h_from:h_to, :] -def center_crop_to_smallest(x, y): +def center_crop_to_smallest( + x: torch.Tensor, y: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply a center crop on the larger image to the size of the smaller. @@ -144,11 +161,11 @@ def center_crop_to_smallest(x, y): be a mixture of the two. Args: - x (torch.Tensor): The first image. - y (torch.Tensor): The second image + x: The first image. + y: The second image. Returns: - tuple: tuple of tensors x and y, each cropped to the minimim size. + tuple of tensors x and y, each cropped to the minimim size. """ smallest_width = min(x.shape[-1], y.shape[-1]) smallest_height = min(x.shape[-2], y.shape[-2]) @@ -158,26 +175,32 @@ def center_crop_to_smallest(x, y): return x, y -def normalize(data, mean, stddev, eps=0.0): +def normalize( + data: torch.Tensor, + mean: Union[float, torch.Tensor], + stddev: Union[float, torch.Tensor], + eps: Union[float, torch.Tensor] = 0.0, +) -> torch.Tensor: """ Normalize the given tensor. Applies the formula (data - mean) / (stddev + eps). Args: - data (torch.Tensor): Input data to be normalized. - mean (float): Mean value. - stddev (float): Standard deviation. - eps (float, optional): Added to stddev to prevent dividing by zero. - Defaults to 0.0. + data: Input data to be normalized. + mean: Mean value. + stddev: Standard deviation. + eps: Added to stddev to prevent dividing by zero. Returns: - torch.Tensor: Normalized tensor + Normalized tensor. """ return (data - mean) / (stddev + eps) -def normalize_instance(data, eps=0.0): +def normalize_instance( + data: torch.Tensor, eps: Union[float, torch.Tensor] = 0.0 +) -> Tuple[torch.Tensor, Union[torch.Tensor], Union[torch.Tensor]]: """ Normalize the given tensor with instance norm/ @@ -185,9 +208,8 @@ def normalize_instance(data, eps=0.0): are computed from the data itself. Args: - data (torch.Tensor): Input data to be normalized - eps (float, optional): Added to stddev to prevent dividing by zero. - Defaults to 0.0. + data: Input data to be normalized + eps: Added to stddev to prevent dividing by zero. Returns: torch.Tensor: Normalized tensor @@ -203,17 +225,20 @@ class UnetDataTransform: Data Transformer for training U-Net models. """ - def __init__(self, which_challenge, mask_func=None, use_seed=True): + def __init__( + self, + which_challenge: str, + mask_func: Optional[MaskFunc] = None, + use_seed: bool = True, + ): """ Args: - which_challenge (str): Either "singlecoil" or "multicoil" denoting - the dataset. - mask_func (fastmri.data.subsample.MaskFunc, optional): A function - that can create a mask of appropriate shape. Defaults to None. - use_seed (bool, optional): If true, this class computes a pseudo - random number generator seed from the filename. This ensures - that the same mask is used for all the slices of a given volume - every time. Defaults to True. + which_challenge: Challenge from ("singlecoil", "multicoil"). + mask_func: Optional; A function that can create a mask of + appropriate shape. + use_seed: If true, this class computes a pseudo random number + generator seed from the filename. This ensures that the same + mask is used for all the slices of a given volume every time. """ if which_challenge not in ("singlecoil", "multicoil"): raise ValueError("Challenge should either be 'singlecoil' or 'multicoil'") @@ -222,28 +247,33 @@ def __init__(self, which_challenge, mask_func=None, use_seed=True): self.which_challenge = which_challenge self.use_seed = use_seed - def __call__(self, kspace, mask, target, attrs, fname, slice_num): + def __call__( + self, + kspace: np.ndarray, + mask: torch.Tensor, + target: np.ndarray, + attrs: Dict, + fname: str, + slice_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, int, float]: """ Args: - kspace (numpy.array): Input k-space of shape (num_coils, rows, - cols, 2) for multi-coil data or (rows, cols, 2) for single coil - data. - mask (numpy.array): Mask from the test dataset. - target (numpy.array): Target image. - attrs (dict): Acquisition related information stored in the HDF5 - object. - fname (str): File name. - slice_num (int): Serial number of the slice. + kspace: Input k-space of shape (num_coils, rows, cols, 2) for + multi-coil data or (rows, cols, 2) for single coil data. + mask: Mask from the test dataset. + target: Target image. + attrs: Acquisition related information stored in the HDF5 object. + fname: File name. + slice_num: Serial number of the slice. Returns: - (tuple): tuple containing: - image (torch.Tensor): Zero-filled input image. - target (torch.Tensor): Target image converted to a torch - Tensor. - mean (float): Mean value used for normalization. - std (float): Standard deviation value used for normalization. - fname (str): File name. - slice_num (int): Serial number of the slice. + tuple containing: + image: Zero-filled input image. + target: Target image converted to a torch.Tensor. + mean: Mean value used for normalization. + std: Standard deviation value used for normalization. + fname: File name. + slice_num: Serial number of the slice. """ kspace = to_tensor(kspace) @@ -300,42 +330,46 @@ class VarNetDataTransform: Data Transformer for training VarNet models. """ - def __init__(self, mask_func=None, use_seed=True): + def __init__(self, mask_func: Optional[MaskFunc] = None, use_seed: bool = True): """ Args: - mask_func (fastmri.data.subsample.MaskFunc, optional): A function - that can create a mask of appropriate shape. Defaults to None. - use_seed (bool, optional): If true, this class computes a pseudo - random number generator seed from the filename. This ensures - that the same mask is used for all the slices of a given volume - every time. Defaults to True. + mask_func: Optional; A function that can create a mask of + appropriate shape. Defaults to None. + use_seed: If True, this class computes a pseudo random number + generator seed from the filename. This ensures that the same + mask is used for all the slices of a given volume every time. """ self.mask_func = mask_func self.use_seed = use_seed - def __call__(self, kspace, mask, target, attrs, fname, slice_num): + def __call__( + self, + kspace: np.ndarray, + mask: torch.Tensor, + target: torch.Tensor, + attrs: Dict, + fname: str, + slice_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, str, int, float, torch.Tensor]: """ Args: - kspace (numpy.array): Input k-space of shape (num_coils, rows, - cols, 2) for multi-coil data or (rows, cols, 2) for single coil - data. - mask (numpy.array): Mask from the test dataset. - target (numpy.array): Target image. - attrs (dict): Acquisition related information stored in the HDF5 - object. - fname (str): File name. - slice_num (int): Serial number of the slice. + kspace: Input k-space of shape (num_coils, rows, cols, 2) for + multi-coil data. + mask: Mask from the test dataset. + target: Target image. + attrs: Acquisition related information stored in the HDF5 object. + fname: File name. + slice_num: Serial number of the slice. Returns: - (tuple): tuple containing: - masked_kspace (torch.Tensor): k-space after applying sampling - mask. - mask (torch.Tensor): The applied sampling mask - target (torch.Tensor): The target image (if applicable). - fname (str): File name. - slice_num (int): The slice index. - max_value (float): Maximum image value. - crop_size (torch.Tensor): the size to crop the final image. + tuple containing: + masked_kspace: k-space after applying sampling mask. + mask: The applied sampling mask + target: The target image (if applicable). + fname: File name. + slice_num: The slice index. + max_value: Maximum image value. + crop_size: The size to crop the final image. """ if target is not None: target = to_tensor(target) @@ -362,7 +396,7 @@ def __call__(self, kspace, mask, target, attrs, fname, slice_num): shape[:-3] = 1 mask_shape = [1 for _ in shape] mask_shape[-2] = num_cols - mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + mask = mask.reshape(*mask_shape) mask[:, :, :acq_start] = 0 mask[:, :, acq_end:] = 0 diff --git a/fastmri/data/volume_sampler.py b/fastmri/data/volume_sampler.py index 5f2e4840..0de9806e 100644 --- a/fastmri/data/volume_sampler.py +++ b/fastmri/data/volume_sampler.py @@ -6,10 +6,12 @@ """ import math +from typing import List, Optional, Union import numpy as np import torch import torch.distributed as dist +from fastmri.data.mri_data import CombinedSliceDataset, SliceDataset from torch.utils.data import Sampler @@ -23,21 +25,27 @@ class VolumeSampler(Sampler): fname is essentially the volume name (actually a filename). """ - def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0): + def __init__( + self, + dataset: Union[CombinedSliceDataset, SliceDataset], + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + ): """ Args: - dataset (torch.utils.data.Dataset): An MRI dataset (e.g., SliceData). - num_replicas (int, optional): Number of processes participating in - distributed training. By default, :attr:`rank` is retrieved - from the current distributed group. - rank (int, optional): Rank of the current process within - :attr:`num_replicas`. By default, :attr:`rank` is retrieved - from the current distributed group. - shuffle (bool, optional): If ``True`` (default), sampler will - shuffle the indices. - seed (int, optional): random seed used to shuffle the sampler if + dataset: An MRI dataset (e.g., SliceData). + num_replicas: Number of processes participating in distributed + training. By default, :attr:`rank` is retrieved from the + current distributed group. + rank: Rank of the current process within :attr:`num_replicas`. By + default, :attr:`rank` is retrieved from the current distributed + group. + shuffle: If ``True`` (default), sampler will shuffle the indices. + seed: random seed used to shuffle the sampler if :attr:`shuffle=True`. This number should be identical across - all processes in the distributed group. Default: ``0``. + all processes in the distributed group. """ if num_replicas is None: if not dist.is_available(): @@ -59,14 +67,14 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0): # get all file names and split them based on number of processes self.all_volume_names = np.array( - sorted(list({example[0] for example in self.dataset.examples})) + sorted([example[0] for example in self.dataset.examples]) ) self.all_volumes_split = np.array_split( self.all_volume_names, self.num_replicas ) # get slice indices for each file name - indices = [list() for _ in range(self.num_replicas)] + indices: List[List[int]] = [[] for _ in range(self.num_replicas)] for i, example in enumerate(self.dataset.examples): vname = example[0] diff --git a/fastmri/evaluate.py b/fastmri/evaluate.py index 1280d531..dfdb8f1a 100644 --- a/fastmri/evaluate.py +++ b/fastmri/evaluate.py @@ -8,6 +8,7 @@ import argparse import pathlib from argparse import ArgumentParser +from typing import Optional import h5py import numpy as np @@ -17,22 +18,22 @@ from fastmri.data import transforms -def mse(gt, pred): +def mse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray: """Compute Mean Squared Error (MSE)""" return np.mean((gt - pred) ** 2) -def nmse(gt, pred): +def nmse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray: """Compute Normalized Mean Squared Error (NMSE)""" return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 -def psnr(gt, pred): +def psnr(gt: np.ndarray, pred: np.ndarray) -> np.ndarray: """Compute Peak Signal to Noise Ratio metric (PSNR)""" return peak_signal_noise_ratio(gt, pred, data_range=gt.max()) -def ssim(gt, pred, maxval=None): +def ssim(gt: np.ndarray, pred: np.ndarray, maxval: Optional[float]) -> np.ndarray: """Compute Structural Similarity Index Metric (SSIM)""" maxval = gt.max() if maxval is None else maxval @@ -42,15 +43,13 @@ def ssim(gt, pred, maxval=None): gt[slice_num], pred[slice_num], data_range=maxval ) - ssim = ssim / gt.shape[0] - - return ssim + return ssim / gt.shape[0] METRIC_FUNCS = dict(MSE=mse, NMSE=nmse, PSNR=psnr, SSIM=ssim,) -class Metrics(object): +class Metrics: """ Maintains running statistics for a given collection of metrics. """ diff --git a/fastmri/losses.py b/fastmri/losses.py index 3c6c3463..981efcee 100644 --- a/fastmri/losses.py +++ b/fastmri/losses.py @@ -15,12 +15,12 @@ class SSIMLoss(nn.Module): SSIM loss module. """ - def __init__(self, win_size=7, k1=0.01, k2=0.03): + def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03): """ Args: - win_size (int, default=7): Window size for SSIM calculation. - k1 (float, default=0.1): k1 parameter for SSIM calculation. - k2 (float, default=0.03): k2 parameter for SSIM calculation. + win_size: Window size for SSIM calculation. + k1: k1 parameter for SSIM calculation. + k2: k2 parameter for SSIM calculation. """ super().__init__() self.win_size = win_size @@ -29,12 +29,14 @@ def __init__(self, win_size=7, k1=0.01, k2=0.03): NP = win_size ** 2 self.cov_norm = NP / (NP - 1) - def forward(self, X, Y, data_range): + def forward(self, X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor): + assert isinstance(self.w, torch.Tensor) + data_range = data_range[:, None, None, None] C1 = (self.k1 * data_range) ** 2 C2 = (self.k2 * data_range) ** 2 - ux = F.conv2d(X, self.w) - uy = F.conv2d(Y, self.w) + ux = F.conv2d(X, self.w) # typing: ignore + uy = F.conv2d(Y, self.w) # uxx = F.conv2d(X * X, self.w) uyy = F.conv2d(Y * Y, self.w) uxy = F.conv2d(X * Y, self.w) diff --git a/fastmri/math.py b/fastmri/math.py index 074ae07d..6e4394b9 100644 --- a/fastmri/math.py +++ b/fastmri/math.py @@ -5,10 +5,13 @@ LICENSE file in the root directory of this source tree. """ +from typing import List, Optional, Tuple, Union + +import numpy as np import torch -def complex_mul(x, y): +def complex_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ Complex multiplication. @@ -16,20 +19,22 @@ def complex_mul(x, y): real arrays with the last dimension being the complex dimension. Args: - x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. - y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + x: A PyTorch tensor with the last dimension of size 2. + y: A PyTorch tensor with the last dimension of size 2. Returns: - torch.Tensor: A PyTorch tensor with the last dimension of size 2. + A PyTorch tensor with the last dimension of size 2. """ - assert x.shape[-1] == y.shape[-1] == 2 + if not (x.shape[-1] == y.shape[-1] == 2): + raise ValueError("Tensors do not have separate complex dim.") + re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] return torch.stack((re, im), dim=-1) -def complex_conj(x): +def complex_conj(x: torch.Tensor) -> torch.Tensor: """ Complex conjugate. @@ -37,31 +42,33 @@ def complex_conj(x): last dimension as the complex dimension. Args: - x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. - y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + x: A PyTorch tensor with the last dimension of size 2. + y: A PyTorch tensor with the last dimension of size 2. Returns: - torch.Tensor: A PyTorch tensor with the last dimension of size 2. + A PyTorch tensor with the last dimension of size 2. """ - assert x.shape[-1] == 2 + if not x.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") return torch.stack((x[..., 0], -x[..., 1]), dim=-1) -def fft2c(data): +def fft2c(data: torch.Tensor) -> torch.Tensor: """ Apply centered 2 dimensional Fast Fourier Transform. Args: - data (torch.Tensor): Complex valued input data containing at least 3 - dimensions: dimensions -3 & -2 are spatial dimensions and dimension - -1 has size 2. All other dimensions are assumed to be batch - dimensions. + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. Returns: - torch.Tensor: The FFT of the input. + The FFT of the input. """ - assert data.size(-1) == 2 + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + data = ifftshift(data, dim=(-3, -2)) data = torch.fft(data, 2, normalized=True) data = fftshift(data, dim=(-3, -2)) @@ -69,20 +76,21 @@ def fft2c(data): return data -def ifft2c(data): +def ifft2c(data: torch.Tensor) -> torch.Tensor: """ Apply centered 2-dimensional Inverse Fast Fourier Transform. Args: - data (torch.Tensor): Complex valued input data containing at least 3 - dimensions: dimensions -3 & -2 are spatial dimensions and dimension - -1 has size 2. All other dimensions are assumed to be batch - dimensions. + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. Returns: - torch.Tensor: The IFFT of the input. + The IFFT of the input. """ - assert data.size(-1) == 2 + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + data = ifftshift(data, dim=(-3, -2)) data = torch.ifft(data, 2, normalized=True) data = fftshift(data, dim=(-3, -2)) @@ -90,75 +98,91 @@ def ifft2c(data): return data -def complex_abs(data): +def complex_abs(data: torch.Tensor) -> torch.Tensor: """ Compute the absolute value of a complex valued input tensor. Args: - data (torch.Tensor): A complex valued tensor, where the size of the - final dimension should be 2. + data: A complex valued tensor, where the size of the final dimension + should be 2. Returns: - torch.Tensor: Absolute value of data. + Absolute value of data. """ - assert data.size(-1) == 2 + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") return (data ** 2).sum(dim=-1).sqrt() -def complex_abs_sq(data): +def complex_abs_sq(data: torch.Tensor) -> torch.Tensor: """ Compute the squared absolute value of a complex tensor. Args: - data (torch.Tensor): A complex valued tensor, where the size of the - final dimension should be 2. + data: A complex valued tensor, where the size of the final dimension + should be 2. Returns: - torch.Tensor: Squared absolute value of data. + Squared absolute value of data. """ - assert data.size(-1) == 2 + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + return (data ** 2).sum(dim=-1) # Helper functions -def roll(x, shift, dim): +def roll( + x: torch.Tensor, + shift: Union[int, Tuple[int, ...], List[int]], + dim: Union[int, Tuple[int, ...], List[int]], +) -> torch.Tensor: """ Similar to np.roll but applies to PyTorch Tensors. Args: - x (torch.Tensor): A PyTorch tensor. - shift (int): Amount to roll. - dim (int): Which dimension to roll. + x: A PyTorch tensor. + shift: Amount to roll. + dim: Which dimension to roll. Returns: - torch.Tensor: Rolled version of x. + Rolled version of x. """ if isinstance(shift, (tuple, list)): + if not isinstance(dim, (tuple, list)): + raise ValueError("Passed Sequence for shift but not for dim.") assert len(shift) == len(dim) for s, d in zip(shift, dim): x = roll(x, s, d) return x + elif isinstance(dim, (tuple, list)): + raise ValueError("Passed Sequence for dim but not for shift.") + shift = shift % x.size(dim) if shift == 0: return x + left = x.narrow(dim, 0, x.size(dim) - shift) right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) -def fftshift(x, dim=None): +def fftshift( + x: torch.Tensor, dim: Optional[Union[Tuple[int, ...], List[int]]] = None +) -> torch.Tensor: """ Similar to np.fft.fftshift but applies to PyTorch Tensors Args: - x (torch.Tensor): A PyTorch tensor. - dim (int): Which dimension to fftshift. + x: A PyTorch tensor. + dim: Which dimension to fftshift. Returns: - torch.Tensor: fftshifted version of x. + fftshifted version of x. """ if dim is None: dim = tuple(range(x.dim())) @@ -171,16 +195,18 @@ def fftshift(x, dim=None): return roll(x, shift, dim) -def ifftshift(x, dim=None): +def ifftshift( + x: torch.Tensor, dim: Optional[Union[Tuple[int, ...], List[int]]] = None +) -> torch.Tensor: """ Similar to np.fft.ifftshift but applies to PyTorch Tensors Args: - x (torch.Tensor): A PyTorch tensor. - dim (int): Which dimension to ifftshift. + x: A PyTorch tensor. + dim: Which dimension to ifftshift. Returns: - torch.Tensor: ifftshifted version of x. + ifftshifted version of x. """ if dim is None: dim = tuple(range(x.dim())) @@ -193,14 +219,16 @@ def ifftshift(x, dim=None): return roll(x, shift, dim) -def tensor_to_complex_np(data): +def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray: """ Converts a complex torch tensor to numpy array. + Args: - data (torch.Tensor): Input data to be converted to numpy. + data: Input data to be converted to numpy. Returns: - np.array: Complex numpy version of data + Complex numpy version of data. """ data = data.numpy() + return data[..., 0] + 1j * data[..., 1] diff --git a/fastmri/models/unet.py b/fastmri/models/unet.py index 4072c192..ea2837eb 100644 --- a/fastmri/models/unet.py +++ b/fastmri/models/unet.py @@ -20,17 +20,21 @@ class Unet(nn.Module): Springer, 2015. """ - def __init__(self, in_chans, out_chans, chans=32, num_pool_layers=4, drop_prob=0.0): + def __init__( + self, + in_chans: int, + out_chans: int, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): """ Args: - in_chans (int): Number of channels in the input to the U-Net model. - out_chans (int): Number of channels in the output to the U-Net - model. - chans (int): Number of output channels of the first convolution - layer. - num_pool_layers (int): Number of down-sampling and up-sampling - layers. - drop_prob (float): Dropout probability. + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. """ super().__init__() @@ -62,15 +66,13 @@ def __init__(self, in_chans, out_chans, chans=32, num_pool_layers=4, drop_prob=0 ) ] - def forward(self, image): + def forward(self, image: torch.Tensor) -> torch.Tensor: """ Args: - image (torch.Tensor): Input tensor of shape [batch_size, - self.in_chans, height, width] + image: Input 4D tensor of shape `(N, in_chans, H, W)`. Returns: - (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, - height, width] + Output tensor of shape `(N, out_chans, H, W)`. """ stack = [] output = image @@ -109,12 +111,12 @@ class ConvBlock(nn.Module): instance normalization, LeakyReLU activation and dropout. """ - def __init__(self, in_chans, out_chans, drop_prob): + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): """ Args: - in_chans (int): Number of channels in the input. - out_chans (int): Number of channels in the output. - drop_prob (float): Dropout probability. + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. """ super().__init__() @@ -133,24 +135,16 @@ def __init__(self, in_chans, out_chans, drop_prob): nn.Dropout2d(drop_prob), ) - def forward(self, image): + def forward(self, image: torch.Tensor) -> torch.Tensor: """ Args: - image (torch.Tensor): Input tensor of shape [batch_size, - self.in_chans, height, width] + image: Input 4D tensor of shape `(N, in_chans, H, W)`. Returns: - (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, - height, width] + Output tensor of shape `(N, out_chans, H, W)`. """ return self.layers(image) - def __repr__(self): - return ( - f"ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans}, " - f"drop_prob={self.drop_prob})" - ) - class TransposeConvBlock(nn.Module): """ @@ -158,11 +152,11 @@ class TransposeConvBlock(nn.Module): layers followed by instance normalization and LeakyReLU activation. """ - def __init__(self, in_chans, out_chans): + def __init__(self, in_chans: int, out_chans: int): """ Args: - in_chans (int): Number of channels in the input. - out_chans (int): Number of channels in the output. + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. """ super().__init__() @@ -177,17 +171,12 @@ def __init__(self, in_chans, out_chans): nn.LeakyReLU(negative_slope=0.2, inplace=True), ) - def forward(self, image): + def forward(self, image: torch.Tensor) -> torch.Tensor: """ Args: - image (torch.Tensor): Input tensor of shape [batch_size, - self.in_chans, height, width] + image: Input 4D tensor of shape `(N, in_chans, H, W)`. Returns: - (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, - height, width] + Output tensor of shape `(N, out_chans, H*2, W*2)`. """ return self.layers(image) - - def __repr__(self): - return f"ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans})" diff --git a/fastmri/models/varnet.py b/fastmri/models/varnet.py index 8c553a7c..f90ab0cf 100644 --- a/fastmri/models/varnet.py +++ b/fastmri/models/varnet.py @@ -6,12 +6,12 @@ """ import math +from typing import Tuple +import fastmri import torch import torch.nn as nn import torch.nn.functional as F - -import fastmri from fastmri.data import transforms from .unet import Unet @@ -26,17 +26,21 @@ class NormUnet(nn.Module): during training. """ - def __init__(self, chans, num_pools, in_chans=2, out_chans=2, drop_prob=0): + def __init__( + self, + chans: int, + num_pools: int, + in_chans: int = 2, + out_chans: int = 2, + drop_prob: float = 0.0, + ): """ Args: - chans (int): Number of output channels of the first convolution - layer. - num_pools (int): Number of down-sampling and up-sampling layers. - in_chans (int, default=2): Number of channels in the input to the - U-Net model. - out_chans (int, default=2): Number of channels in the output to the - U-Net model. - drop_prob (float, default=0): Dropout probability. + chans: Number of output channels of the first convolution layer. + num_pools: Number of down-sampling and up-sampling layers. + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + drop_prob: Dropout probability. """ super().__init__() @@ -48,18 +52,18 @@ def __init__(self, chans, num_pools, in_chans=2, out_chans=2, drop_prob=0): drop_prob=drop_prob, ) - def complex_to_chan_dim(self, x): + def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w, two = x.shape assert two == 2 return x.permute(0, 4, 1, 2, 3).contiguous().view(b, 2 * c, h, w) - def chan_complex_to_last_dim(self, x): + def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor: b, c2, h, w = x.shape assert c2 % 2 == 0 c = c2 // 2 return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous() - def norm(self, x): + def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Group norm b, c, h, w = x.shape x = x.contiguous().view(b, 2, c // 2 * h * w) @@ -83,10 +87,14 @@ def norm(self, x): return (x - mean) / std, mean, std - def unnorm(self, x, mean, std): + def unnorm( + self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor + ) -> torch.Tensor: return x * std + mean - def pad(self, x): + def pad( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, Tuple[Tuple[int, int], Tuple[int, int], int, int]]: def floor_ceil(n): return math.floor(n), math.ceil(n) @@ -99,14 +107,25 @@ def floor_ceil(n): return x, (h_pad, w_pad, h_mult, w_mult) - def unpad(self, x, h_pad, w_pad, h_mult, w_mult): + def unpad( + self, + x: torch.Tensor, + h_pad: Tuple[int, int], + w_pad: Tuple[int, int], + h_mult: int, + w_mult: int, + ) -> torch.Tensor: return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]] - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + # get shapes for unet and normalize x = self.complex_to_chan_dim(x) x, mean, std = self.norm(x) x, pad_sizes = self.pad(x) + x = self.unet(x) + + # get shapes back and unnormalize x = self.unpad(x, *pad_sizes) x = self.unnorm(x, mean, std) x = self.chan_complex_to_last_dim(x) @@ -123,17 +142,21 @@ class SensitivityModel(nn.Module): end-to-end variational network. """ - def __init__(self, chans, num_pools, in_chans=2, out_chans=2, drop_prob=0): + def __init__( + self, + chans: int, + num_pools: int, + in_chans: int = 2, + out_chans: int = 2, + drop_prob: float = 0.0, + ): """ Args: - chans (int): Number of output channels of the first convolution - layer. - num_pools (int): Number of down-sampling and up-sampling layers. - in_chans (int, default=2): Number of channels in the input to the - U-Net model. - out_chans (int, default=2): Number of channels in the output to the - U-Net model. - drop_prob (float, default=0): Dropout probability. + chans: Number of output channels of the first convolution layer. + num_pools: Number of down-sampling and up-sampling layers. + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + drop_prob: Dropout probability. """ super().__init__() @@ -145,37 +168,38 @@ def __init__(self, chans, num_pools, in_chans=2, out_chans=2, drop_prob=0): drop_prob=drop_prob, ) - def chans_to_batch_dim(self, x): + def chans_to_batch_dim(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: b, c, *other = x.shape return x.contiguous().view(b * c, 1, *other), b - def batch_chans_to_chan_dim(self, x, batch_size): + def batch_chans_to_chan_dim(self, x: torch.Tensor, batch_size: int) -> torch.Tensor: bc, _, *other = x.shape c = bc // batch_size return x.view(batch_size, c, *other) - def divide_root_sum_of_squares(self, x): + def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor: return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1) - def forward(self, masked_kspace, mask): - def get_low_frequency_lines(mask): - l = r = mask.shape[-2] // 2 - while mask[..., r, :]: - r += 1 - - while mask[..., l, :]: - l -= 1 - - return l + 1, r + def forward(self, masked_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # get low frequency line locations and mask them out + left = right = mask.shape[-2] // 2 + while mask[..., right, :]: + right += 1 - l, r = get_low_frequency_lines(mask) - num_low_freqs = r - l + while mask[..., left, :]: + left -= 1 + num_low_freqs = right - left pad = (mask.shape[-2] - num_low_freqs + 1) // 2 + x = transforms.mask_center(masked_kspace, pad, pad + num_low_freqs) + + # convert to image space x = fastmri.ifft2c(x) x, b = self.chans_to_batch_dim(x) + + # estimate sensitivities x = self.norm_unet(x) x = self.batch_chans_to_chan_dim(x, b) x = self.divide_root_sum_of_squares(x) @@ -191,18 +215,24 @@ class VarNet(nn.Module): regularizer. To use non-U-Net regularizers, use VarNetBock. """ - def __init__(self, num_cascades=12, sens_chans=8, sens_pools=4, chans=18, pools=4): + def __init__( + self, + num_cascades: int = 12, + sens_chans: int = 8, + sens_pools: int = 4, + chans: int = 18, + pools: int = 4, + ): """ Args: - num_cascades (int, default=12): Number of cascades (i.e., layers) - for variational network. - sens_chans (int, default=8): Number of channels for sensitivity map + num_cascades: Number of cascades (i.e., layers) for variational + network. + sens_chans: Number of channels for sensitivity map U-Net. + sens_pools Number of downsampling and upsampling layers for + sensitivity map U-Net. + chans: Number of channels for cascade U-Net. + pools: Number of downsampling and upsampling layers for cascade U-Net. - sens_pools (int, default=8): Number of downsampling and upsampling - layers for sensitivity map U-Net. - chans (int, default=18): Number of channels for cascade U-Net. - pools (int, default=4): Number of downsampling and upsampling - layers for cascade U-Net. """ super().__init__() @@ -211,7 +241,7 @@ def __init__(self, num_cascades=12, sens_chans=8, sens_pools=4, chans=18, pools= [VarNetBlock(NormUnet(chans, pools)) for _ in range(num_cascades)] ) - def forward(self, masked_kspace, mask): + def forward(self, masked_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: sens_maps = self.sens_net(masked_kspace, mask) kspace_pred = masked_kspace.clone() @@ -230,33 +260,37 @@ class VarNetBlock(nn.Module): the full variational network. """ - def __init__(self, model): + def __init__(self, model: nn.Module): """ Args: - model (torch.nn.Module): Module for "regularization" component of - variational network. + model: Module for "regularization" component of variational + network. """ super().__init__() self.model = model self.dc_weight = nn.Parameter(torch.ones(1)) - self.register_buffer("zero", torch.zeros(1, 1, 1, 1, 1)) - - def forward(self, current_kspace, ref_kspace, mask, sens_maps): - def sens_expand(x): - return fastmri.fft2c(fastmri.complex_mul(x, sens_maps)) - def sens_reduce(x): - x = fastmri.ifft2c(x) - return fastmri.complex_mul(x, fastmri.complex_conj(sens_maps)).sum( - dim=1, keepdim=True - ) + def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: + return fastmri.fft2c(fastmri.complex_mul(x, sens_maps)) - def soft_dc(x): - return torch.where(mask, x - ref_kspace, self.zero) * self.dc_weight + def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: + x = fastmri.ifft2c(x) + return fastmri.complex_mul(x, fastmri.complex_conj(sens_maps)).sum( + dim=1, keepdim=True + ) - return ( - current_kspace - - soft_dc(current_kspace) - - sens_expand(self.model(sens_reduce(current_kspace))) + def forward( + self, + current_kspace: torch.Tensor, + ref_kspace: torch.Tensor, + mask: torch.Tensor, + sens_maps: torch.Tensor, + ) -> torch.Tensor: + zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace) + soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight + model_term = self.sens_expand( + self.model(self.sens_reduce(current_kspace, sens_maps)), sens_maps ) + + return current_kspace - soft_dc - model_term diff --git a/fastmri/pl_modules/data_module.py b/fastmri/pl_modules/data_module.py index d0f3f2fb..2f74307f 100644 --- a/fastmri/pl_modules/data_module.py +++ b/fastmri/pl_modules/data_module.py @@ -5,8 +5,9 @@ LICENSE file in the root directory of this source tree. """ -import pathlib from argparse import ArgumentParser +from pathlib import Path +from typing import Callable, Optional import fastmri import pytorch_lightning as pl @@ -30,45 +31,36 @@ class FastMriDataModule(pl.LightningDataModule): def __init__( self, - data_path, - challenge, - train_transform, - val_transform, - test_transform, - test_split="test", - test_path=None, - sample_rate=1.0, - batch_size=1, - num_workers=4, - distributed_sampler=False, + data_path: Path, + challenge: str, + train_transform: Callable, + val_transform: Callable, + test_transform: Callable, + test_split: str = "test", + test_path: Optional[Path] = None, + sample_rate: float = 1.0, + batch_size: int = 1, + num_workers: int = 4, + distributed_sampler: bool = False, ): """ Args: - data_path (pathlib.Path): Path to root data directory. For example, - if knee/path is the root directory with subdirectories - multicoil_train and multicoil_val, you would input knee/path - for data_path. - challenge (str): Name of challenge from ('multicoil', - 'singlecoil'). - train_transform (Callable): A transform object for the training - split. - val_transform (Callable): A transform object for the validation - split. - test_transform (Callable): A transform object for the test split. - test_split (str, optional): Name of test split from ("test", - "challenge"). Defaults to "test". - test_path (pathlib.Path, optional): An optional test path. - Passing this overwrites data_path and test_split. Defaults to - None. - sample_rate (float, optional): Fraction of of the training data - split to use. Can be set to less than 1.0 for rapid - prototyping. Defaults to 1.0. - batch_size (int, optional): Batch size. Defaults to 1. - num_workers (int, optional): Number of workers for PyTorch - dataloader. Defaults to 4. - distributed_sampler (bool, optional): Whether to use a distributed - sampler. This should be set to True if training with ddp. - Defaults to False. + data_path: Path to root data directory. For example, if knee/path + is the root directory with subdirectories multicoil_train and + multicoil_val, you would input knee/path for data_path. + challenge: Name of challenge from ('multicoil', 'singlecoil'). + train_transform: A transform object for the training split. + val_transform: A transform object for the validation split. + test_transform: A transform object for the test split. + test_split: Name of test split from ("test", "challenge"). + test_path: An optional test path. Passing this overwrites data_path + and test_split. + sample_rate: Fraction of of the training data split to use. Can be + set to less than 1.0 for rapid prototyping. + batch_size: Batch size. + num_workers: Number of workers for PyTorch dataloader. + distributed_sampler: Whether to use a distributed sampler. This + should be set to True if training with ddp. """ super().__init__() @@ -84,7 +76,12 @@ def __init__( self.num_workers = num_workers self.distributed_sampler = distributed_sampler - def _create_data_loader(self, data_transform, data_partition, sample_rate=None): + def _create_data_loader( + self, + data_transform: Callable, + data_partition: str, + sample_rate: Optional[float] = None, + ) -> torch.utils.data.DataLoader: if data_partition == "train": is_train = True sample_rate = sample_rate or self.sample_rate @@ -144,15 +141,12 @@ def add_data_specific_args(parent_parser): # pragma: no-cover # dataset arguments parser.add_argument( - "--data_path", - default=None, - type=pathlib.Path, - help="Path to fastMRI data root", + "--data_path", default=None, type=Path, help="Path to fastMRI data root", ) parser.add_argument( "--test_path", default=None, - type=pathlib.Path, + type=Path, help="Path to data for test mode. This overwrites data_path and test_split", ) parser.add_argument( diff --git a/fastmri/pl_modules/mri_module.py b/fastmri/pl_modules/mri_module.py index e5597093..744f9f30 100644 --- a/fastmri/pl_modules/mri_module.py +++ b/fastmri/pl_modules/mri_module.py @@ -22,7 +22,7 @@ def __init__(self, dist_sync_on_step=True): self.add_state("quantity", default=torch.tensor(0.0), dist_reduce_fx="sum") - def update(self, batch): + def update(self, batch: torch.Tensor): # type: ignore self.quantity += batch def compute(self): @@ -49,11 +49,10 @@ class MriModule(pl.LightningModule): Other methods from LightningModule can be overridden as needed. """ - def __init__(self, num_log_images=16): + def __init__(self, num_log_images: int = 16): """ Args: - num_log_images (int, optional): Number of images to log. Defaults - to 16. + num_log_images: Number of images to log. Defaults to 16. """ super().__init__() diff --git a/fastmri/pl_modules/varnet_module.py b/fastmri/pl_modules/varnet_module.py index 767e6227..f620f183 100644 --- a/fastmri/pl_modules/varnet_module.py +++ b/fastmri/pl_modules/varnet_module.py @@ -22,36 +22,31 @@ class VarNetModule(MriModule): def __init__( self, - num_cascades=12, - pools=4, - chans=18, - sens_pools=4, - sens_chans=8, - lr=0.0003, - lr_step_size=40, - lr_gamma=0.1, - weight_decay=0.0, + num_cascades: int = 12, + pools: int = 4, + chans: int = 18, + sens_pools: int = 4, + sens_chans: int = 8, + lr: float = 0.0003, + lr_step_size: int = 40, + lr_gamma: float = 0.1, + weight_decay: float = 0.0, **kwargs, ): """ Args: - num_cascades (int, optional): Number of cascades (i.e., layers) - for variational network. Defaults to 12. - pools (int, optional): Number of downsampling and upsampling - layers for cascade U-Net. Defaults to 4. - chans (int, optional): Number of channels for cascade U-Net. - Defaults to 18. - sens_pools (int, optional): Number of downsampling and upsampling - layers for sensitivity map U-Net. Defaults to 4. - sens_chans (int, optional): Number of channels for sensitivity map - U-Net. Defaults to 8. - lr (float, optional): Learning rate. Defaults to 0.0003. - lr_step_size (int, optional): Learning rate step size. Defaults to - 40. - lr_gamma (float, optional): Learning rate gamma decay. Defaults to - 0.0. - weight_decay (float, optional): Parameter for penalizing weights - norm. Defaults to 0.0. + num_cascades: Number of cascades (i.e., layers) for variational + network. + pools: Number of downsampling and upsampling layers for cascade + U-Net. + chans: Number of channels for cascade U-Net. + sens_pools: Number of downsampling and upsampling layers for + sensitivity map U-Net. + sens_chans: Number of channels for sensitivity map U-Net. + lr: Learning rate. + lr_step_size: Learning rate step size. + lr_gamma: Learning rate gamma decay. + weight_decay: Parameter for penalizing weights norm. """ super().__init__(**kwargs) self.save_hyperparameters() diff --git a/fastmri/utils.py b/fastmri/utils.py index 0890dce2..617a20fa 100644 --- a/fastmri/utils.py +++ b/fastmri/utils.py @@ -5,10 +5,14 @@ LICENSE file in the root directory of this source tree. """ +from pathlib import Path +from typing import Dict + import h5py +import numpy as np -def save_reconstructions(reconstructions, out_dir): +def save_reconstructions(reconstructions: Dict[str, np.ndarray], out_dir: Path): """ Save reconstruction images. @@ -16,11 +20,10 @@ def save_reconstructions(reconstructions, out_dir): leaderboard. Args: - reconstructions (dict[str, np.array]): A dictionary mapping input - filenames to corresponding reconstructions (of shape num_slices x - height x width). - out_dir (pathlib.Path): Path to the output directory where the - reconstructions should be saved. + reconstructions: A dictionary mapping input filenames to corresponding + reconstructions. + out_dir: Path to the output directory where the reconstructions should + be saved. """ out_dir.mkdir(exist_ok=True, parents=True) for fname, recons in reconstructions.items(): diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..21ea37ac --- /dev/null +++ b/mypy.ini @@ -0,0 +1,9 @@ +[mypy] + +# modules that don't play well with mypy +[mypy-numpy.*,h5py.*,runstats.*,skimage.metrics.*,setuptools.*,bart.*,torchvision.*,tqdm.*] +ignore_missing_imports=True + +# directories we're not tracking +[mypy-tests.*,models.*,data.*,common.*,banding_removal.*] +ignore_missing_imports=True \ No newline at end of file