Add typing (#91)
* Add typing to data folder

* Add typing to models

* Better comments

* Add typing to pl_modules

* Typing for base module, config

* Test fixes

* Add mypy test

* Add dev requirements

* Adding a few more ignores for mypy

* Single install cache
mmuckley authored Oct 27, 2020
1 parent a29123f commit 0abc8e2
Showing 17 changed files with 627 additions and 481 deletions.
7 changes: 2 additions & 5 deletions .circleci/config.yml
@@ -9,12 +9,8 @@ jobs:
executor: python/default
steps:
- checkout
- run:
name: Preinstallation Packages
command: |
pip install wheel
pip install pytest
- python/install-packages:
pip-dependency-file: dev-requirements.txt
pkg-manager: pip
- run:
name: Install fastMRI
@@ -25,6 +21,7 @@ jobs:
command: |
pytest --version
pytest tests
mypy fastmri
name: Test

workflows:
4 changes: 4 additions & 0 deletions dev-requirements.txt
@@ -0,0 +1,4 @@
-r requirements.txt
wheel
pytest
mypy
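
For context, this is the sort of mistake the new "mypy fastmri" CI step catches now that the package is annotated. A hypothetical sketch, not part of the commit, assuming rss is re-exported at the package root as in the fastmri package:

import torch

import fastmri

coil_images = torch.randn(8, 320, 320)
combined = fastmri.rss(coil_images, dim=0)  # OK: Tensor in, Tensor out

# mypy: Argument 1 to "rss" has incompatible type "str"; expected "Tensor"
broken = fastmri.rss("not a tensor")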
16 changes: 8 additions & 8 deletions fastmri/coil_combine.py
@@ -10,33 +10,33 @@
import fastmri


def rss(data, dim=0):
def rss(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
"""
Compute the Root Sum of Squares (RSS).
RSS is computed assuming that dim is the coil dimension.
Args:
data (torch.Tensor): The input tensor
dim (int): The dimensions along which to apply the RSS transform
data: The input tensor
dim: The dimension along which to apply the RSS transform
Returns:
torch.Tensor: The RSS value.
The RSS value.
"""
return torch.sqrt((data ** 2).sum(dim))


def rss_complex(data, dim=0):
def rss_complex(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
"""
Compute the Root Sum of Squares (RSS) for complex inputs.
RSS is computed assuming that dim is the coil dimension.
Args:
data (torch.Tensor): The input tensor
dim (int): The dimensions along which to apply the RSS transform
data: The input tensor
dim: The dimension along which to apply the RSS transform
Returns:
torch.Tensor: The RSS value.
The RSS value.
"""
return torch.sqrt(fastmri.complex_abs_sq(data).sum(dim))
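
In equation form, rss computes sqrt(sum_i x_i^2) with i running over the coil dimension. A minimal sketch of both functions on dummy data (shapes are illustrative; fastMRI stores complex tensors with a trailing real/imaginary dimension of size 2):

import torch

from fastmri.coil_combine import rss, rss_complex

# Eight coils of 320x320 magnitude images.
coil_images = torch.randn(8, 320, 320)
print(rss(coil_images, dim=0).shape)  # torch.Size([320, 320])

# Complex data: the trailing dimension of size 2 holds (real, imag).
coil_complex = torch.randn(8, 320, 320, 2)
print(rss_complex(coil_complex, dim=0).shape)  # torch.Size([320, 320])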
138 changes: 82 additions & 56 deletions fastmri/data/mri_data.py
@@ -6,10 +6,12 @@
"""

import logging
import pathlib
import os
import pickle
import random
import xml.etree.ElementTree as etree
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
from warnings import warn

import h5py
@@ -18,20 +18,25 @@
import yaml


def et_query(root, qlist, namespace="http://www.ismrm.org/ISMRMRD"):
def et_query(
root: etree.Element,
qlist: Sequence[str],
namespace: str = "http://www.ismrm.org/ISMRMRD",
) -> str:
"""
ElementTree query function.
This can be used to query an xml document via ElementTree. It uses qlist
for nested queries.
Args:
root (xml.etree.ElementTree.Element): Root of the xml.
qlist (Sequence): A list of strings for nested searches.
namespace (str): xml namespace.
root: Root of the xml to search through.
qlist: A list of strings for nested searches, e.g. ["Encoding",
"matrixSize"]
namespace: Optional; xml namespace to prepend to the query.
Returns:
str: The retrieved data.
The retrieved data as a string.
"""
s = "."
prefix = "ismrmrd_namespace"
@@ -41,10 +48,16 @@ def et_query(root, qlist, namespace="http://www.ismrm.org/ISMRMRD"):
for el in qlist:
s = s + f"//{prefix}:{el}"

return root.find(s, ns).text
value = root.find(s, ns)
if value is None:
raise RuntimeError("Element not found")

return str(value.text)
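
For illustration, a self-contained sketch of querying a toy ISMRMRD-style header with et_query (the XML content below is made up for the example):

import xml.etree.ElementTree as etree

from fastmri.data.mri_data import et_query

header = etree.fromstring(
    '<ismrmrdHeader xmlns="http://www.ismrm.org/ISMRMRD">'
    "<encoding><encodedSpace><matrixSize>"
    "<x>640</x><y>372</y>"
    "</matrixSize></encodedSpace></encoding>"
    "</ismrmrdHeader>"
)

# Nested query down to encoding -> encodedSpace -> matrixSize -> x.
width = et_query(header, ["encoding", "encodedSpace", "matrixSize", "x"])
print(width)  # "640"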

def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")):

def fetch_dir(
key: str, data_config_file: Union[str, Path, os.PathLike] = "fastmri_dirs.yaml"
) -> Path:
"""
Data directory fetcher.
@@ -53,14 +66,15 @@ def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")):
and this function will retrieve the requested subsplit of the data for use.
Args:
key (str): key to retrieve path from data_config_file.
data_config_file (pathlib.Path,
default=pathlib.Path("fastmri_dirs.yaml")): Default path config
file.
key: key to retrieve path from data_config_file. Expected to be in
("knee_path", "brain_path", "log_path").
data_config_file: Optional; Default path config file to fetch path
from.
Returns:
pathlib.Path: The path to the specified directory.
The path to the specified directory.
"""
data_config_file = Path(data_config_file)
if not data_config_file.is_file():
default_config = {
"knee_path": "/path/to/knee",
@@ -81,37 +95,44 @@ def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")):
with open(data_config_file, "r") as f:
data_dir = yaml.safe_load(f)[key]

data_dir = pathlib.Path(data_dir)

return data_dir
return Path(data_dir)
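
A short usage sketch (the YAML contents and paths below are placeholders, not from the commit):

from fastmri.data.mri_data import fetch_dir

# fastmri_dirs.yaml is expected to contain entries such as:
#   knee_path: /datasets/fastmri/knee
#   brain_path: /datasets/fastmri/brain
#   log_path: ./logs
# If the file is missing, fetch_dir first writes a template with
# placeholder paths for the user to edit.
knee_path = fetch_dir("knee_path", "fastmri_dirs.yaml")
print(knee_path)  # e.g. PosixPath('/datasets/fastmri/knee')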


class CombinedSliceDataset(torch.utils.data.Dataset):
"""
A container for combining slice datasets.
Args:
roots (list of pathlib.Path): Paths to the datasets.
transforms (list of callable): A callable object that pre-processes the
raw data into appropriate form. The transform function should take
'kspace', 'target', 'attributes', 'filename', and 'slice' as
inputs. 'target' may be null for test data.
challenges (list of str): "singlecoil" or "multicoil" depending on which
challenge to use.
sample_rates (list of float, optional): A float between 0 and 1. This
controls what fraction of the volumes should be loaded.
num_cols (tuple(int), optional): if provided, only slices with the desired
number of columns will be considered.
"""

def __init__(self, roots, transforms, challenges, sample_rates=None, num_cols=None):
def __init__(
self,
roots: Sequence[Path],
transforms: Sequence[Callable],
challenges: Sequence[str],
sample_rates: Optional[Sequence[float]] = None,
num_cols: Optional[Tuple[int]] = None,
):
"""
Args:
roots: Paths to the datasets.
transforms: A sequence of callable objects that preprocess the raw
data into appropriate form. Each transform should take 'kspace',
'target', 'attributes', 'filename', and 'slice' as inputs.
'target' may be null for test data.
challenges: "singlecoil" or "multicoil" depending on which
challenge to use.
sample_rates: Optional; A sequence of floats between 0 and 1, each
controlling what fraction of that dataset's volumes should be loaded.
num_cols: Optional; If provided, only slices with the desired
number of columns will be considered.
"""
assert len(roots) == len(transforms) == len(challenges)
if sample_rates is not None:
assert len(sample_rates) == len(roots)
else:
sample_rates = [1] * len(roots)

self.datasets = list()
self.datasets = []
self.examples: List[Tuple[Path, int, Dict[str, object]]] = []
for i in range(len(roots)):
self.datasets.append(
SliceDataset(
@@ -123,6 +144,8 @@ def __init__(self, roots, transforms, challenges, sample_rates=None, num_cols=No
)
)

self.examples = self.examples + self.datasets[-1].examples
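
As an aside, a minimal usage sketch for this class; the paths and the transform are hypothetical placeholders:

from pathlib import Path

from fastmri.data.mri_data import CombinedSliceDataset

def passthrough_transform(*sample):
    # Placeholder: real transforms mask, crop, and normalize.
    return sample

# Combine a knee and a brain dataset, loading half of the brain volumes.
dataset = CombinedSliceDataset(
    roots=[Path("/data/knee_train"), Path("/data/brain_train")],
    transforms=[passthrough_transform, passthrough_transform],
    challenges=["singlecoil", "multicoil"],
    sample_rates=[1.0, 0.5],
)
print(len(dataset))  # total slices across both datasets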

def __len__(self):
length = 0
for dataset in self.datasets:
@@ -141,36 +164,37 @@ def __getitem__(self, i):
class SliceDataset(torch.utils.data.Dataset):
"""
A PyTorch Dataset that provides access to MR image slices.
Args:
root (pathlib.Path): Path to the dataset.
transform (callable): A callable object that pre-processes the raw data
into appropriate form. The transform function should take 'kspace',
'target', 'attributes', 'filename', and 'slice' as inputs. 'target'
may be null for test data.
challenge (str): "singlecoil" or "multicoil" depending on which
challenge to use.
sample_rate (float, optional): A float between 0 and 1. This controls
what fraction of the volumes should be loaded.
dataset_cache_file (pathlib.Path). A file in which to cache dataset
information for faster load times. Default: dataset_cache.pkl.
num_cols (tuple(int), optional): if provided, only slices with the desired
number of columns will be considered.
"""

def __init__(
self,
root,
transform,
challenge,
sample_rate=1,
dataset_cache_file=pathlib.Path("dataset_cache.pkl"),
num_cols=None,
root: Union[str, Path, os.PathLike],
transform: Callable,
challenge: str,
sample_rate: float = 1.0,
dataset_cache_file: Union[str, Path, os.PathLike] = "dataset_cache.pkl",
num_cols: Optional[Tuple[int]] = None,
):
"""
Args:
root: Path to the dataset.
transform: A callable object that pre-processes the raw data into
appropriate form. The transform function should take 'kspace',
'target', 'attributes', 'filename', and 'slice' as inputs.
'target' may be null for test data.
challenge: "singlecoil" or "multicoil" depending on which challenge
to use.
sample_rate: Optional; A float between 0 and 1. This controls what
fraction of the volumes should be loaded. Defaults to 1.0.
dataset_cache_file: Optional; A file in which to cache dataset
information for faster load times.
num_cols: Optional; If provided, only slices with the desired
number of columns will be considered.
"""
if challenge not in ("singlecoil", "multicoil"):
raise ValueError('challenge should be either "singlecoil" or "multicoil"')

self.dataset_cache_file = dataset_cache_file
self.dataset_cache_file = Path(dataset_cache_file)

self.transform = transform
self.recons_key = (
@@ -185,7 +209,7 @@ def __init__(
dataset_cache = {}

if dataset_cache.get(root) is None:
files = list(pathlib.Path(root).iterdir())
files = list(Path(root).iterdir())
for fname in sorted(files):
with h5py.File(fname, "r") as hf:
et_root = etree.fromstring(hf["ismrmrd_header"][()])
@@ -238,13 +262,15 @@ def __init__(

if num_cols:
self.examples = [
ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols
ex
for ex in self.examples
if ex[2]["encoding_size"][1] in num_cols # type: ignore
]

def __len__(self):
return len(self.examples)

def __getitem__(self, i):
def __getitem__(self, i: int):
fname, dataslice, metadata = self.examples[i]

with h5py.File(fname, "r") as hf:
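
Finally, a hedged usage sketch for SliceDataset itself; the path, sample rate, and column counts are illustrative only:

from fastmri.data.mri_data import SliceDataset

def passthrough_transform(*sample):
    # Placeholder: real transforms mask, crop, and normalize the k-space.
    return sample

dataset = SliceDataset(
    root="/data/multicoil_train",   # hypothetical dataset location
    transform=passthrough_transform,
    challenge="multicoil",
    sample_rate=0.2,                # load 20% of the volumes
    num_cols=(368, 372),            # keep only slices with these widths
)
print(len(dataset))
first_sample = dataset[0]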