Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jan 19, 2024
1 parent 40178ca commit 866bc5b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
3 changes: 2 additions & 1 deletion direct/nn/unet/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

from dataclasses import dataclass
from typing import Optional

from direct.config.defaults import ModelConfig

Expand Down
4 changes: 4 additions & 0 deletions direct/nn/vsharp/vsharp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ def _do_iteration(
output_kspace: TensorOrNone

with autocast(enabled=self.mixed_precision):
if self.cfg.model.conv_modulation: # type: ignore
data["auxiliary_data"] = torch.cat([data["acceleration"], data["center_fraction"]], 1)

output_images, output_kspace = self.forward_function(data)
output_images = [T.modulus_if_complex(_, complex_axis=self._complex_dim) for _ in output_images]
loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()}
Expand Down Expand Up @@ -251,6 +254,7 @@ def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]:
masked_kspace=data["masked_kspace"],
sampling_mask=data["sampling_mask"],
sensitivity_map=data["sensitivity_map"],
auxiliary_data=data["auxiliary_data"] if self.cfg.model.conv_modulation else None,
) # shape (batch, height, width, complex[=2])

output_image = output_images[-1]
Expand Down
11 changes: 11 additions & 0 deletions tests/tests_common/test_subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,14 @@ def test_apply_mask_cartesian(mask_func, shape, center_fractions, accelerations)
mask_func = mask_func(center_fractions=center_fractions, accelerations=accelerations)
mask = mask_func(shape[1:], seed=123)
acs_mask = mask_func(shape[1:], seed=123, return_acs=True)
_, acceleration, center_fraction = mask_func(shape[1:], seed=123, return_acceleration=True)
expected_mask_shape = (1, shape[1], shape[2], 1)

assert mask.max() == 1
assert mask.min() == 0
assert mask.shape == expected_mask_shape
assert np.allclose(mask & acs_mask, acs_mask)
assert acceleration in accelerations and center_fraction in center_fractions


@pytest.mark.parametrize(
Expand Down Expand Up @@ -186,11 +188,13 @@ def test_apply_mask_calgary_campinas(shape, accelerations):
mask = mask_func(shape[1:], seed=123)
acs_mask = mask_func(shape[1:], seed=123, return_acs=True)
expected_mask_shape = (1, shape[1], shape[2], 1)
_, acceleration, _ = mask_func(shape[1:], seed=123, return_acceleration=True)

assert mask.max() == 1
assert mask.min() == 0
assert mask.shape == expected_mask_shape
assert acs_mask.shape == expected_mask_shape
assert acceleration in accelerations


@pytest.mark.parametrize(
Expand Down Expand Up @@ -223,11 +227,13 @@ def test_apply_mask_radial(shape, accelerations):
mask = mask_func(shape[1:], seed=123)
acs_mask = mask_func(shape[1:], seed=123, return_acs=True)
expected_mask_shape = (1, shape[1], shape[2], 1)
acceleration = mask_func(shape[1:], seed=123, return_acceleration=True)[1]

assert mask.max() == 1
assert mask.min() == 0
assert mask.shape == expected_mask_shape
assert np.allclose(mask & acs_mask, acs_mask)
assert acceleration in accelerations


@pytest.mark.parametrize(
Expand Down Expand Up @@ -261,11 +267,13 @@ def test_apply_mask_spiral(shape, accelerations):
mask = mask_func(shape[1:], seed=123)
acs_mask = mask_func(shape[1:], seed=123, return_acs=True)
expected_mask_shape = (1, shape[1], shape[2], 1)
acceleration = mask_func(shape[1:], seed=123, return_acceleration=True)[1]

assert mask.max() == 1
assert mask.min() == 0
assert mask.shape == expected_mask_shape
assert np.allclose(mask & acs_mask, acs_mask)
assert acceleration in accelerations


@pytest.mark.parametrize(
Expand Down Expand Up @@ -311,10 +319,13 @@ def test_apply_mask_poisson(shape, accelerations, center_fractions, seed):
)
mask = mask_func(shape[1:], seed=seed)
acs_mask = mask_func(shape[1:], seed=seed, return_acs=True)
_, acceleration, center_fraction = mask_func(shape[1:], seed=123, return_acceleration=True)

expected_mask_shape = (1, shape[1], shape[2], 1)
assert mask.max() == 1
assert mask.min() == 0
assert mask.shape == expected_mask_shape
assert acceleration in accelerations and center_fraction in center_fractions
if seed is not None:
assert np.allclose(mask & acs_mask, acs_mask)

Expand Down

0 comments on commit 866bc5b

Please sign in to comment.