Change loupe to two-dim sampling
georgeyiasemis committed Nov 30, 2023
1 parent 276efef commit e5ad693
Showing 2 changed files with 38 additions and 18 deletions.
4 changes: 2 additions & 2 deletions direct/nn/adaptive/config.py
@@ -25,7 +25,7 @@ class PolicyConfig(ModelConfig):

@dataclass
class LOUPEPolicyConfig(PolicyConfig):
-    num_actions: int = MISSING
+    kspace_shape: tuple[int, ...] = MISSING


@dataclass
@@ -35,7 +35,7 @@ class LOUPE3dPolicyConfig(PolicyConfig):

@dataclass
class MultiStraightThroughPolicyConfig(PolicyConfig):
-    image_size: tuple[int, int] = MISSING
+    kspace_shape: tuple[int, int] = MISSING
num_layers: int = 2
num_fc_layers: int = 3
fc_size: int = 256
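The renamed field makes the sampling dimensionality explicit in the config: a one-element `kspace_shape` keeps the old column-sampling behaviour, while a two-element shape enables point sampling. A minimal sketch of how the new configs would be built (shape values are illustrative; `MISSING` means the field must be supplied):

```python
from direct.nn.adaptive.config import LOUPEPolicyConfig

# 1D: one action per k-space column (line sampling), as before.
loupe_1d = LOUPEPolicyConfig(kspace_shape=(320,))

# 2D: one action per k-space point (the new two-dim sampling).
loupe_2d = LOUPEPolicyConfig(kspace_shape=(320, 320))
```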
52 changes: 36 additions & 16 deletions direct/nn/adaptive/policy.py
@@ -6,6 +6,7 @@

from typing import Callable, Optional

+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -25,7 +26,7 @@ def __init__(
self,
acceleration: float,
center_fraction: float,
-        num_actions: int,
+        kspace_shape: tuple[int, ...],
use_softplus: bool = True,
slope: float = 10,
fix_sign_leakage: bool = True,
@@ -34,7 +35,12 @@
st_clamp: bool = False,
):
super().__init__()
-        # shape = [1, W]
+        if len(kspace_shape) not in [1, 2]:
+            raise ValueError(
+                f"kspace_shape of LOUPEPolicy should have length 1 or 2. Received kspace_shape={kspace_shape}."
+            )
+        self.dim = len(kspace_shape)
+        num_actions = np.prod(kspace_shape)
self.use_softplus = use_softplus
self.slope = slope
self.st_slope = st_slope
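The number of actions is now derived rather than passed in: `np.prod(kspace_shape)` yields one action per column in the 1D case and one per k-space point in the 2D case. A quick sketch with assumed shapes:

```python
import numpy as np

num_actions_1d = np.prod((320,))      # 320 actions: one per k-space column
num_actions_2d = np.prod((320, 320))  # 102400 actions: one per k-space point
```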
@@ -61,9 +67,13 @@

def forward(self, mask: torch.Tensor, kspace: torch.Tensor, padding: Optional[torch.Tensor] = None):
batch_size, _, height, width, _ = kspace.shape # batch, coils, height, width, complex
-        mask = mask[:, :, 0, :, :].reshape(batch_size, 1, 1, width, 1)
        masks = [mask]
-        # Reshape to [B, W]
+        if self.dim == 1:
+            mask = mask[:, :, 0, :, :].reshape(batch_size, width)
+        else:
+            mask = mask.reshape(batch_size, height * width)

+        # Reshape to [B, num_actions]
sampler_out = self.sampler.expand(batch_size, -1)
if self.use_softplus:
# Softplus to make positive
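Both branches flatten the incoming mask to `[B, num_actions]`, so the sampler logits and the mask line up elementwise regardless of dimensionality. A toy shape check (shapes assumed for illustration):

```python
import torch

batch_size, coils, height, width = 2, 1, 6, 8
mask = torch.zeros(batch_size, coils, height, width, 1)

flat_1d = mask[:, :, 0, :, :].reshape(batch_size, width)  # [2, 8]: one entry per column
flat_2d = mask.reshape(batch_size, height * width)        # [2, 48]: one entry per point
```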
@@ -79,10 +89,13 @@ def forward(self, mask: torch.Tensor, kspace: torch.Tensor, padding: Optional[torch.Tensor] = None):
masked_prob_mask = prob_mask * (1 - mask.reshape(prob_mask.shape[0], prob_mask.shape[1]))
# Mask out padded areas
if padding is not None:
-            padding = padding[:, :, 0, :, :].reshape(batch_size, width)
+            if self.dim == 1:
+                padding = padding[:, :, 0, :, :].reshape(batch_size, width)
+            else:
+                padding = padding.reshape(batch_size, height * width)
masked_prob_mask = masked_prob_mask * (1 - padding)
# Take out zero (masked) probabilities, since we don't want to include those in the normalisation
-        nonzero_idcs = (mask.view(batch_size, width) == 0).nonzero(as_tuple=True)
+        nonzero_idcs = (mask == 0).nonzero(as_tuple=True)
probs_to_norm = masked_prob_mask[nonzero_idcs].reshape(batch_size, -1)
# Rescale probabilities to desired sparsity.
normed_probs = rescale_probs(probs_to_norm, self.budget)
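Because the mask is already flat, `(mask == 0)` indexes the unacquired positions directly in both cases; the old `mask.view(batch_size, width)` only worked for column sampling. A toy example of the gather step (values assumed; `rescale_probs` then renormalizes the gathered entries to the budget):

```python
import torch

mask = torch.tensor([[1.0, 0.0, 0.0, 1.0]])              # [B=1, num_actions=4]
masked_prob_mask = torch.tensor([[0.0, 0.3, 0.7, 0.0]])  # acquired entries zeroed out
nonzero_idcs = (mask == 0).nonzero(as_tuple=True)
probs_to_norm = masked_prob_mask[nonzero_idcs].reshape(1, -1)  # tensor([[0.3, 0.7]])
```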
@@ -92,11 +105,18 @@ def forward(self, mask: torch.Tensor, kspace: torch.Tensor, padding: Optional[torch.Tensor] = None):
# Binarize the mask
flat_bin_mask = self.binarizer(masked_prob_mask)

-        # BCHW --> BW --> B11W1 --> B1HW1
-        acquisitions = flat_bin_mask.reshape(batch_size, 1, 1, width, 1).expand(batch_size, 1, height, width, 1)
-        final_prob_mask = masked_prob_mask.reshape(batch_size, 1, 1, width, 1).expand(batch_size, 1, height, width, 1)
-        # B11H1
-        mask = mask.expand(batch_size, 1, height, width, 1)
+        if self.dim == 1:
+            # BCHW --> BW --> B11W1 --> B1HW1
+            acquisitions = flat_bin_mask.reshape(batch_size, 1, 1, width, 1).expand(batch_size, 1, height, width, 1)
+            final_prob_mask = masked_prob_mask.reshape(batch_size, 1, 1, width, 1).expand(
+                batch_size, 1, height, width, 1
+            )
+            mask = mask.reshape(batch_size, 1, 1, width, 1).expand(batch_size, 1, height, width, 1)
+        else:
+            # BCHW --> BH*W --> B1HW1
+            acquisitions = flat_bin_mask.reshape(batch_size, 1, height, width, 1)
+            final_prob_mask = masked_prob_mask.reshape(batch_size, 1, height, width, 1)
+            mask = mask.reshape(batch_size, 1, height, width, 1)
mask = mask + acquisitions
masks.append(mask)
# BMHWC
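The two branches encode the geometric difference between the modes: in 1D the `width` binary decisions are broadcast down the height axis, so every selected column is acquired as a full line; in 2D the `height * width` decisions map one-to-one onto grid points. A toy illustration (shapes assumed):

```python
import torch

batch_size, height, width = 1, 3, 4

# 1D: a per-column decision broadcast over all rows -> full k-space lines.
flat_1d = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
lines = flat_1d.reshape(batch_size, 1, 1, width, 1).expand(batch_size, 1, height, width, 1)

# 2D: a per-point decision reshaped onto the grid -> individual k-space points.
flat_2d = torch.randint(0, 2, (batch_size, height * width)).float()
points = flat_2d.reshape(batch_size, 1, height, width, 1)
```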
@@ -214,7 +234,7 @@ def __init__(
self,
budget: int,
backward_operator: Callable,
-        image_size: tuple[int, int] = (128, 128),
+        kspace_shape: tuple[int, int] = (128, 128),
slope: float = 10,
sampler_detach_mask: bool = False,
kspace_sampler: bool = False,
@@ -231,7 +251,7 @@ def __init__(
super().__init__()

self.sampler = (KSpaceLineConvSampler if kspace_sampler else ImageLineConvSampler)(
-            input_dim=(2, *image_size),
+            input_dim=(2, *kspace_shape),
slope=slope,
use_softplus=use_softplus,
fc_size=fc_size,
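With the rename, the conv sampler is still built on a two-channel grid; presumably the leading 2 holds the real/imaginary channels, so the sampler sees `[B, 2, *kspace_shape]` inputs. A quick shape sketch (assumed):

```python
import torch

kspace_shape = (128, 128)
input_dim = (2, *kspace_shape)  # (2, 128, 128)
x = torch.zeros(4, *input_dim)  # a batch of 4 two-channel inputs
assert x.shape == (4, 2, 128, 128)
```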
@@ -329,7 +349,7 @@ def __init__(
center_fraction: float,
backward_operator: Callable,
num_layers: int = 1,
-        image_size: tuple[int, int] = (128, 128),
+        kspace_shape: tuple[int, int] = (128, 128),
slope: float = 10,
kspace_sampler: bool = False,
sampler_detach_mask: bool = False,
@@ -347,7 +367,7 @@

self.layers = nn.ModuleList()

-        num_cols = image_size[-1]
+        num_cols = kspace_shape[-1]
budget = int(num_cols / acceleration - num_cols * center_fraction)
layer_budget = budget // num_layers
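The budget arithmetic is unchanged, only its source renamed; note that it still counts columns (`kspace_shape[-1]`), so the multi-layer policy remains line-based. A worked example with assumed values:

```python
# Assumed: kspace_shape=(128, 128), acceleration=4, center_fraction=0.08, num_layers=2.
num_cols = 128
budget = int(num_cols / 4 - num_cols * 0.08)  # int(32.0 - 10.24) = 21
layer_budget = budget // 2                    # 10 lines per policy layer
```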

@@ -359,7 +379,7 @@ def __init__(
StraightThroughPolicy(
budget=layer_budget,
backward_operator=backward_operator,
-                    image_size=image_size,
+                    kspace_shape=kspace_shape,
slope=slope,
sampler_detach_mask=sampler_detach_mask,
kspace_sampler=kspace_sampler,
