Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SSL and JSSL functionality #277

Merged
merged 31 commits into from
Apr 26, 2024
Merged
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
b6f566a
Add base ssl functionality
georgeyiasemis Apr 11, 2024
540b3a2
Add SSL transforms
georgeyiasemis Apr 11, 2024
f0c9f1d
Updated ssl tests
georgeyiasemis Apr 11, 2024
8d46dbb
Minor fix
georgeyiasemis Apr 11, 2024
d7a79ea
Initial ssl models
georgeyiasemis Apr 11, 2024
2c7840c
Small changes for SSL and test
georgeyiasemis Apr 11, 2024
cc7ea05
black
georgeyiasemis Apr 11, 2024
db41ad9
black
georgeyiasemis Apr 11, 2024
b6cabb9
Import annotations
georgeyiasemis Apr 11, 2024
0f76fc9
Add docstrings
georgeyiasemis Apr 11, 2024
fbe66a2
Minor docstrings changes
georgeyiasemis Apr 16, 2024
7eaba03
Add JSSL initial engines
georgeyiasemis Apr 17, 2024
92c0ef9
SSL fixes
georgeyiasemis Apr 17, 2024
57ca113
Unet ssl/jssl tests
georgeyiasemis Apr 18, 2024
851f289
is_ssl_training -> is_ssl
georgeyiasemis Apr 18, 2024
c8c5532
Varnet ssl and jssl engines
georgeyiasemis Apr 18, 2024
d31199e
Minor fix
georgeyiasemis Apr 18, 2024
6404ac7
Minor fixes
georgeyiasemis Apr 18, 2024
daa28c5
Code quality fixes
georgeyiasemis Apr 18, 2024
83ab7fb
Code quality fixes
georgeyiasemis Apr 18, 2024
49577ad
Code quality fixes
georgeyiasemis Apr 18, 2024
036af26
Code quality fixes in mri transforms
georgeyiasemis Apr 18, 2024
5b2629a
Remove useless option - new pylint
georgeyiasemis Apr 18, 2024
dc0ae07
Codacy quality fixes
georgeyiasemis Apr 18, 2024
67f0222
Where to put disable msg?
georgeyiasemis Apr 18, 2024
68e30bb
Add docstrings
georgeyiasemis Apr 19, 2024
8a52b24
Add reference
georgeyiasemis Apr 19, 2024
0d81470
Enum typing doesn't require checks
georgeyiasemis Apr 19, 2024
50aca9a
Omegaconf doesn't accept future annotations
georgeyiasemis Apr 19, 2024
2572474
Test fix
georgeyiasemis Apr 19, 2024
924bc5f
Minor fix
georgeyiasemis Apr 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
SSL fixes
georgeyiasemis committed Apr 17, 2024
commit 92c0ef955af0b06c179d9d4faae4d7fde355cb28
56 changes: 31 additions & 25 deletions direct/nn/ssl/mri_models.py
Original file line number Diff line number Diff line change
@@ -229,11 +229,10 @@ def _do_iteration(
if self.model.training:
# SSL: project the predicted k-space to target k-space, i.e. predict locations only in target k-space
output_kspace = T.apply_mask(output_kspace, data["target_sampling_mask"], return_mask=False)
# Compute loss and regularizer in k-space domain
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, None, output_kspace)
regularizer_dict = self.compute_loss_on_data(
regularizer_dict, regularizer_fns, data, None, output_kspace
)

# Compute loss and regularizer in k-space domain
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, None, output_kspace)
regularizer_dict = self.compute_loss_on_data(regularizer_dict, regularizer_fns, data, None, output_kspace)

# Compute image via SENSE reconstruction
output_image = T.modulus(
@@ -243,12 +242,17 @@ def _do_iteration(
self._coil_dim,
)
)

# Compute loss and regularizer loss in image domain
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, None)
regularizer_dict = self.compute_loss_on_data(regularizer_dict, regularizer_fns, data, output_image, None)

# Compute total loss
loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore

# Backward pass
if self.model.training:
# Compute loss and regularizer loss in image domain
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, None)
regularizer_dict = self.compute_loss_on_data(
regularizer_dict, regularizer_fns, data, output_image, None
)
self._scaler.scale(loss).backward()

loss_dict = detach_dict(loss_dict) # Detach dict, only used for logging.
regularizer_dict = detach_dict(regularizer_dict)
@@ -469,16 +473,13 @@ def _do_iteration(
# Data consistency (followed by padding if it exists)
output_kspace = T.apply_padding(kspace + output_kspace, padding=data.get("padding", None))

if self.model.training:
if is_ssl_training:
# SSL: project the predicted k-space to target k-space
output_kspace = T.apply_mask(output_kspace, data["target_sampling_mask"], return_mask=False)

# Compute loss and regularizer loss in k-space domain
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, None, output_kspace)
regularizer_dict = self.compute_loss_on_data(
regularizer_dict, regularizer_fns, data, None, output_kspace
)
if self.model.training and is_ssl_training:
# SSL: project the predicted k-space to target k-space
output_kspace = T.apply_mask(output_kspace, data["target_sampling_mask"], return_mask=False)

# Compute loss and regularizer loss in k-space domain
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, None, output_kspace)
regularizer_dict = self.compute_loss_on_data(regularizer_dict, regularizer_fns, data, None, output_kspace)

# Compute image via SENSE reconstruction
output_image = T.modulus(
@@ -488,12 +489,17 @@ def _do_iteration(
self._coil_dim,
)
)

# Compute loss and regularizer loss in image domain
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, None)
regularizer_dict = self.compute_loss_on_data(regularizer_dict, regularizer_fns, data, output_image, None)

# Compute total loss
loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore

# Backward pass
if self.model.training:
# Compute loss and regularizer loss in image domain
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, None)
regularizer_dict = self.compute_loss_on_data(
regularizer_dict, regularizer_fns, data, output_image, None
)
self._scaler.scale(loss).backward()

loss_dict = detach_dict(loss_dict) # Detach dict, only used for logging.
regularizer_dict = detach_dict(regularizer_dict)
4 changes: 2 additions & 2 deletions direct/nn/unet/unet_engine.py
Original file line number Diff line number Diff line change
@@ -271,9 +271,9 @@ def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, None]:
# Get the k-space and mask which differ if SSL training or supervised training
# The also differ during training and inference for SSL
if is_ssl_training and self.model.training:
kspace, mask = data["input_kspace"], data["input_sampling_mask"]
kspace = data["input_kspace"]
else:
kspace, mask = data["masked_kspace"], data["sampling_mask"]
kspace = data["masked_kspace"]

sensitity_map = (
data["sensitivity_map"] if self.cfg.model.image_initialization == "sense" else None # type: ignore