Restormer Implementation for MONAI: High-Resolution Image Restoration #8261
@Nic-Ma and @ericspod and @KumoLiu - this seems like an outstanding addition to MONAI - agreed? @phisanti - if all approve, please look at our contribution guidelines. You are already doing the exact right thing by having a modular design. Whenever appropriate, please support the exploration of alternative components in this framework via that modular design and appropriate class abstractions. Please also include multiple tutorials and unit tests with your work. Does your code currently exist in another repo that we could preliminarily review? Thanks! |
You can take a look at the modular implementation of the Restormer architecture here. I have also copied the code below. As you can see, I keep many of the key blocks intact and focus on expanding functionality (flash attention) and on making the encoder/decoder blocks modular. I am happy to make further changes if you have suggestions.
```python
"""
Restormer: Efficient Transformer for High-Resolution Image Restoration
Implementation based on: https://arxiv.org/abs/2111.09881
"""
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from monai.networks.layers import Norm


class FeedForward(nn.Module):
    """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using a gating mechanism.
    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection."""

    def __init__(self, dim: int, ffn_expansion_factor: float, bias: bool):
        super().__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
                                stride=1, padding=1, groups=hidden_features * 2, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.project_in(x)
        # Split into a gate branch (x1) and a value branch (x2)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        return self.project_out(F.gelu(x1) * x2)


class Attention(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA). Differs from standard self-attention
    by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
    convolutions for local mixing before attention, achieving linear complexity vs quadratic
    in vanilla attention."""

    def __init__(self, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
        super().__init__()
        if flash_attention and not hasattr(F, 'scaled_dot_product_attention'):
            raise ValueError("Flash attention not available")
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.flash_attention = flash_attention
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1,
                                    padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self._attention_fn = self._get_attention_fn()

    def _get_attention_fn(self):
        if self.flash_attention:
            return self._flash_attention
        return self._normal_attention

    def _flash_attention(self, q, k, v):
        """Flash attention implementation using scaled dot-product attention."""
        # NOTE: float(...) turns the learnable temperature into a Python scalar, so no
        # gradient reaches it on this path, and one mean value is shared by all heads
        # (unlike _normal_attention, which applies a per-head temperature).
        scale = float(self.temperature.mean())
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            scale=scale,
            dropout_p=0.0,
            is_causal=False
        )
        return out

    def _normal_attention(self, q, k, v):
        """Attention matrix multiplication over channels with a per-head temperature."""
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        return attn @ v

    def forward(self, x):
        """Forward pass for MDTA attention.
        1. Apply depth-wise convolutions to Q, K, V
        2. Reshape Q, K, V for multi-head attention
        3. Compute attention matrix using flash or normal attention
        4. Reshape and project out attention output"""
        b, c, h, w = x.shape
        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        # L2-normalize along the flattened spatial axis so q @ k^T is a cosine similarity
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        out = self._attention_fn(q, k, v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out


class TransformerBlock(nn.Module):
    """Basic transformer unit combining MDTA and GDFN with skip connections.
    Unlike standard transformers that use LayerNorm, this block uses Instance Norm
    for better adaptation to image restoration tasks."""

    def __init__(self, dim: int, num_heads: int, ffn_expansion_factor: float,
                 bias: bool, LayerNorm_type: str, flash_attention: bool = False):
        super().__init__()
        use_bias = LayerNorm_type != 'BiasFree'
        self.norm1 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias)
        self.attn = Attention(dim, num_heads, bias, flash_attention)
        self.norm2 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual layout: x + attn(norm(x)), then x + ffn(norm(x))
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


class OverlapPatchEmbed(nn.Module):
    """Initial feature extraction using overlapped convolutions.
    Unlike standard patch embeddings that use non-overlapping patches,
    this approach maintains spatial continuity through 3x3 convolutions."""

    def __init__(self, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
                              stride=1, padding=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class Downsample(nn.Module):
    """Downsampling module that halves spatial dimensions while doubling channels.
    Uses PixelUnshuffle for efficient feature map manipulation."""

    def __init__(self, n_feat: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat // 2, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.PixelUnshuffle(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


class Upsample(nn.Module):
    """Upsampling module that doubles spatial dimensions while halving channels.
    Combines convolution with PixelShuffle for efficient feature expansion."""

    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * 2, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.PixelShuffle(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


##---------- Restormer -----------------------
class Restormer(nn.Module):
    """Restormer: Efficient Transformer for High-Resolution Image Restoration.
    Implements a U-Net style architecture with transformer blocks, combining:
    - Multi-scale feature processing through progressive down/upsampling
    - Efficient attention via MDTA blocks
    - Local feature mixing through GDFN
    - Skip connections for preserving spatial details
    Architecture:
    - Encoder: Progressive feature downsampling with increasing channels
    - Latent: Deep feature processing at lowest resolution
    - Decoder: Progressive upsampling with skip connections
    - Refinement: Final feature enhancement
    """

    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=48,
                 num_blocks=[1, 1, 1, 1],
                 heads=[1, 1, 1, 1],
                 num_refinement_blocks=4,
                 ffn_expansion_factor=2.66,
                 bias=False,
                 LayerNorm_type='WithBias',
                 dual_pixel_task=False,
                 flash_attention=False):
        """Initialize Restormer model.
        Args:
            inp_channels: Number of input image channels
            out_channels: Number of output image channels
            dim: Base feature dimension
            num_blocks: Number of transformer blocks at each scale
            heads: Number of attention heads at each scale
            num_refinement_blocks: Number of final refinement blocks
            ffn_expansion_factor: Expansion factor for feed-forward network
            bias: Whether to use bias in convolutions
            LayerNorm_type: Type of normalization ('WithBias' or 'BiasFree')
            dual_pixel_task: Enable dual-pixel specific processing
            flash_attention: Use flash attention if available
        """
        super().__init__()
        # Check input parameters
        assert len(num_blocks) > 1, "num_blocks must define more than one scale"
        assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal"
        assert all(n > 0 for n in num_blocks), "Number of blocks must be greater than 0"
        # Initial feature extraction
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
        self.encoder_levels = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        self.decoder_levels = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.reduce_channels = nn.ModuleList()
        num_steps = len(num_blocks) - 1
        self.num_steps = num_steps
        # Define encoder levels
        for n in range(num_steps):
            current_dim = dim * 2**n
            self.encoder_levels.append(
                nn.Sequential(*[
                    TransformerBlock(
                        dim=current_dim,
                        num_heads=heads[n],
                        ffn_expansion_factor=ffn_expansion_factor,
                        bias=bias,
                        LayerNorm_type=LayerNorm_type,
                        flash_attention=flash_attention
                    ) for _ in range(num_blocks[n])
                ])
            )
            self.downsamples.append(Downsample(current_dim))
        # Define latent space
        latent_dim = dim * 2**num_steps
        self.latent = nn.Sequential(*[
            TransformerBlock(
                dim=latent_dim,
                num_heads=heads[num_steps],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
                flash_attention=flash_attention
            ) for _ in range(num_blocks[num_steps])
        ])
        # Define decoder levels
        for n in reversed(range(num_steps)):
            current_dim = dim * 2**n
            next_dim = dim * 2**(n + 1)
            self.upsamples.append(Upsample(next_dim))
            # Reduce channel layers to deal with skip connections
            if n != 0:
                self.reduce_channels.append(
                    nn.Conv2d(next_dim, current_dim, kernel_size=1, bias=bias)
                )
                decoder_dim = current_dim
            else:
                # The top level keeps the concatenated channels (2 * dim)
                decoder_dim = next_dim
            self.decoder_levels.append(
                nn.Sequential(*[
                    TransformerBlock(
                        dim=decoder_dim,
                        num_heads=heads[n],
                        ffn_expansion_factor=ffn_expansion_factor,
                        bias=bias,
                        LayerNorm_type=LayerNorm_type,
                        flash_attention=flash_attention
                    ) for _ in range(num_blocks[n])
                ])
            )
        # Final refinement and output
        self.refinement = nn.Sequential(*[
            TransformerBlock(
                dim=decoder_dim,
                num_heads=heads[0],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
                flash_attention=flash_attention
            ) for _ in range(num_refinement_blocks)
        ])
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, int(dim * 2**1), kernel_size=1, bias=bias)
        self.output = nn.Conv2d(int(dim * 2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        """Forward pass of Restormer.
        Processes input through encoder-decoder architecture with skip connections.
        Args:
            x: Input image tensor of shape (B, C, H, W)
        Returns:
            Restored image tensor of shape (B, C, H, W)
        """
        assert x.shape[-1] > 2 ** self.num_steps and x.shape[-2] > 2 ** self.num_steps, "Input dimensions should be larger than 2^num_steps"
        # Patch embedding
        x = self.patch_embed(x)
        skip_connections = []
        # Encoding path
        for encoder, downsample in zip(self.encoder_levels, self.downsamples):
            x = encoder(x)
            skip_connections.append(x)
            x = downsample(x)
        # Latent space
        x = self.latent(x)
        # Decoding path
        for idx in range(len(self.decoder_levels)):
            x = self.upsamples[idx](x)
            x = torch.concat([x, skip_connections[-(idx + 1)]], 1)
            # Every level except the last one reduces the concatenated channels
            if idx < len(self.decoder_levels) - 1:
                x = self.reduce_channels[idx](x)
            x = self.decoder_levels[idx](x)
        # Final refinement
        x = self.refinement(x)
        if self.dual_pixel_task:
            x = x + self.skip_conv(skip_connections[0])
            x = self.output(x)
        else:
            x = self.output(x)
        return x


if __name__ == "__main__":
    flash_att = True
    test_model = Restormer(
        inp_channels=2,
        out_channels=2,
        dim=16,
        num_blocks=[1, 1, 1, 1],
        heads=[1, 1, 1, 1],
        num_refinement_blocks=2,
        ffn_expansion_factor=1.5,
        bias=False,
        LayerNorm_type='WithBias',
        dual_pixel_task=True,
        flash_attention=flash_att
    )
    print(f'flash attention set to {flash_att}')
    input_tensor = torch.randn(8, 2, 256, 256)
    print(f"Input shape: {input_tensor.shape}")
    output = test_model(input_tensor)
    print(f"Output shape: {output.shape}")
    print('printing final model')
    from torchsummary import summary
    # torchsummary expects the per-sample shape (C, H, W), not a tensor;
    # the model lives on the CPU here, so request a CPU summary.
    summary(test_model, input_size=tuple(input_tensor.shape[1:]), device="cpu")
```
|
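To make the channel bookkeeping of the encoder/decoder easier to check, here is a minimal pure-Python sketch that traces `(channels, height, width)` through the stages of the code above. The helper name `trace_shapes` is hypothetical and not part of the implementation; it only mirrors the arithmetic of `Downsample` (conv to C/2, then `PixelUnshuffle(2)`), `Upsample` (conv to 2C, then `PixelShuffle(2)`), the skip concatenation and the 1x1 `reduce_channels` convolutions.

```python
def trace_shapes(dim, num_levels, height, width):
    """Return (stage, channels, height, width) tuples for the U-Net layout above."""
    stages = [("patch_embed", dim, height, width)]
    c, h, w = dim, height, width
    num_steps = num_levels - 1
    skips = []
    for n in range(num_steps):
        stages.append((f"encoder_{n}", c, h, w))
        skips.append(c)
        # Downsample: conv C -> C/2, PixelUnshuffle(2) multiplies channels by 4
        c = (c // 2) * 4
        h, w = h // 2, w // 2
    stages.append(("latent", c, h, w))
    for n in reversed(range(num_steps)):
        # Upsample: conv C -> 2C, PixelShuffle(2) divides channels by 4
        c = (c * 2) // 4
        h, w = h * 2, w * 2
        c += skips[n]  # concatenate the encoder skip connection
        if n != 0:
            c //= 2  # 1x1 reduce_channels conv (skipped at the top level)
        stages.append((f"decoder_{n}", c, h, w))
    return stages


stages = trace_shapes(dim=16, num_levels=4, height=256, width=256)
for stage in stages:
    print(stage)
```

With the settings of the `__main__` block (`dim=16`, four levels, 256x256 input), the latent stage works on 128 channels at 32x32, and the last decoder level keeps 32 channels (2 * `dim`), which is why the output convolution takes `dim * 2` input channels.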
@aylward @Nic-Ma @ericspod and @KumoLiu, if you all agree and there are no comments on extra modules to be added, I will implement the class as it is. For that, I will:
What aspects of this approach would you modify to fully align with MONAI's contribution standards? |
Hi @phisanti, thank you for sharing the comprehensive plan! I’d recommend dividing the implementation into several PRs to simplify the review process. Additionally, I highly suggest checking if there are existing blocks in MONAI that can be reused in your network, such as upsample, downsample, attention mechanisms, etc. https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/downsample.py Also, consider using Convolution, which could make your network support both 2D and 3D implementations seamlessly.
|
I have just done an in-depth review of the Upsample/Downsample, SABlock and Transformer blocks present in MONAI. From what I can see, using MONAI's own Upsample/Downsample classes is trivial. I think something like `UpSample(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels//2, mode=UpsampleMode.PIXELSHUFFLE, scale_factor=2, bias=False, apply_pad_pool=False)` should mirror the behaviour of the current Restormer. However, using the MONAI SABlock would not work. The SABlock is a spatial attention mechanism based on the Dosovitskiy paper, whereas Restormer uses a channel attention mechanism. See the code below:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Attention(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA). Differs from standard self-attention
    by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
    convolutions for local mixing before attention, achieving linear complexity vs quadratic
    in vanilla attention."""

    def __init__(self, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
        super().__init__()
        if flash_attention and not hasattr(F, 'scaled_dot_product_attention'):
            raise ValueError("Flash attention not available")
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.flash_attention = flash_attention
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1,
                                    padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self._attention_fn = self._get_attention_fn()

    def _get_attention_fn(self):
        if self.flash_attention:
            return self._flash_attention
        return self._normal_attention

    def _flash_attention(self, q, k, v):
        """Flash attention implementation using scaled dot-product attention."""
        scale = float(self.temperature.mean())
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            scale=scale,
            dropout_p=0.0,
            is_causal=False
        )
        return out

    def _normal_attention(self, q, k, v):
        """Attention matrix multiplication over channels with a per-head temperature."""
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        return attn @ v

    def forward(self, x):
        """Forward pass for MDTA attention.
        1. Apply depth-wise convolutions to Q, K, V
        2. Reshape Q, K, V for multi-head attention
        3. Compute attention matrix using flash or normal attention
        4. Reshape and project out attention output"""
        b, c, h, w = x.shape
        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        out = self._attention_fn(q, k, v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out
```
My suggestion is to implement it as a separate class (CABlock, Channel Attention Block). Then, in the native Transformer block from MONAI, add both attention mechanisms and give the user the argument attention_type="spatial"/"channel". If so, the new class could be included in the Blocks segment. Let me know what you think. |
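To illustrate the spatial-versus-channel distinction without any dependencies, here is a minimal pure-Python sketch of single-head channel attention that follows the same steps as `_normal_attention` above: L2-normalize each channel over the flattened spatial axis, build a C x C similarity map, apply a softmax per row, and mix the channels of `v`. The function names and the toy input are illustrative assumptions, and zero-norm rows are handled with a simple fallback rather than the epsilon used by `F.normalize`.

```python
import math


def l2_normalize(row):
    """Scale a channel (list of N = H*W values) to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in row)) or 1.0  # fallback for all-zero rows
    return [x / norm for x in row]


def softmax(row):
    """Numerically stable softmax over one row of the attention map."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def channel_attention(q, k, v, temperature=1.0):
    """q, k, v: lists of C channels, each a list of N spatial values (one head)."""
    q = [l2_normalize(r) for r in q]
    k = [l2_normalize(r) for r in k]
    # C x C map: entry (i, j) compares channel i of q with channel j of k
    attn = [
        softmax([temperature * sum(a * b for a, b in zip(qi, kj)) for kj in k])
        for qi in q
    ]
    n = len(v[0])
    # Each output channel is a weighted mix of the channels of v
    out = [
        [sum(attn[i][j] * v[j][p] for j in range(len(v))) for p in range(n)]
        for i in range(len(q))
    ]
    return attn, out


# Toy input: C = 3 channels, N = 4 spatial positions (e.g. a 2x2 image)
x = [[1.0, 0.0, 2.0, 1.0], [0.5, 1.5, 0.0, 1.0], [2.0, 2.0, 1.0, 0.0]]
attn, out = channel_attention(x, x, x)
print(len(attn), len(attn[0]))  # attention map is C x C: 3 3
print(len(out), len(out[0]))    # output keeps the C x N layout: 3 4
```

The attention map has C x C entries regardless of image size, whereas a spatial attention block would build an N x N map over the flattened positions; this is the structural difference that keeps an SABlock from being a drop-in replacement.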
A CABlock class seems appropriate, but it would be good if we could avoid
having an argument to toggle between two different variables / use-cases.
Would it be possible for CABlock and SABlock to have a member var that
defines its type: spatial or channel? Then the Transformer class would
query that member var of the block to determine if it is spatial- or
channel-based block and adjust accordingly (assuming transformer's logic
would change if self.attn points to CABlock)? Not certain if API
differences could be resolved - what do you think?
If so, then perhaps the Transformer class would default (as-is) to
self.attn being an SABlock, but after init, a user/caller could overwrite
self.attn with a CABlock?
s
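One way to picture the member-variable idea above is the following pure-Python sketch. All class names here are stand-ins invented for illustration (they are not MONAI classes), and the tensor layouts returned for each attention type are assumptions about how the two kinds of block would consume their input.

```python
class SpatialAttentionStub:
    """Stand-in for a spatial attention block (hypothetical, not MONAI's SABlock)."""
    attention_type = "spatial"

    def __call__(self, x):
        return x


class ChannelAttentionStub:
    """Stand-in for the proposed CABlock (hypothetical)."""
    attention_type = "channel"

    def __call__(self, x):
        return x


class TransformerBlockSketch:
    """Defaults to spatial attention; callers may overwrite self.attn after init."""

    def __init__(self, attn=None):
        self.attn = attn if attn is not None else SpatialAttentionStub()

    def expected_layout(self):
        # Query the block's own type instead of taking a toggle argument
        kind = getattr(self.attn, "attention_type", "spatial")
        if kind == "spatial":
            return "B, N, C"  # assumed token-sequence layout
        if kind == "channel":
            return "B, C, H, W"  # assumed image layout, as in the Attention class
        raise ValueError(f"Unknown attention_type: {kind}")


block = TransformerBlockSketch()
default_layout = block.expected_layout()
block.attn = ChannelAttentionStub()  # swap in the channel block after init
channel_layout = block.expected_layout()
print(default_layout, "|", channel_layout)
```

In this sketch the transformer never receives an `attention_type` argument; it reads the attribute from whichever block `self.attn` points to, so the default behaviour stays unchanged unless the caller swaps the block.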
…On Fri, Jan 10, 2025 at 10:44 AM Cano-Muniz, Santiago < ***@***.***> wrote:
Hi @KumoLiu <https://github.com/KumoLiu> and @aylward
<https://github.com/aylward>
I have just done an in-depth review of the Upsample/Downsample, SABlock
and Transformer blocks present in MONAI. From what I can see, using the
local version of the Upsample/Downsample classes is trivial. I think
something like:
UpSample(
    spatial_dims=spatial_dims,
    in_channels=in_channels,
    out_channels=in_channels//2,
    mode=UpsampleMode.PIXELSHUFFLE,
    scale_factor=2,
    bias=False,
    apply_pad_pool=False
)
should mirror the current behaviour in the current restormer. However
using the MONAI classes for the SAB block would not work. The SABlock is a
spatial attention mechanism based on Dosovitskiy paper. However, the
restormer is a channel attention mechanism. See code below:
My suggestion is to implement it as a separate class (CABlock, Channel
Attention Block). Then, in the native Transformer block from MONAI, add both
attention mechanisms and give the user the argument
attention_type="spatial"/"channel". If so, the new class could be included
in the Blocks segment.
Let me know what you think.
|
Dear @aylward and @KumoLiu, thanks for your patience. Here is an overview of the progress I have made so far: The goal was to implement the
When writing the
Next, I proceeded to implement the
Regarding the
Finally, for every new class and function, I wrote extensive unit tests, trying to cover as many edge cases as I could come up with. However, a second check is always welcome to raise the standard. You can check all the progress in my forked MONAI repo. Please let me know what you think. Once you have done a quick overview, I can proceed with the following steps:
Once this is done, the implementation will be ready for review and integration into the main MONAI repository. |
Is your feature request related to a problem? Please describe.
I've noticed that MONAI currently lacks dedicated models for image denoising and restoration tasks. While MONAI provides excellent tools for medical image analysis, having specialized architectures for improving image quality would be valuable for preprocessing pipelines and enhancing low-quality medical images (microscopy, X-ray, scans...).
Describe the solution you'd like
I have implemented a well-documented version of the Restormer model (https://arxiv.org/abs/2111.09881) that could be contributed to MONAI. The implementation includes key components like:
Describe alternatives you've considered
The implementation is already structured in a modular way with clear separation of components. I'm willing to:
Additional context
The code is currently functional and tested. It supports both standard and dual-pixel tasks, with configurable parameters for network depth, attention heads, and feature dimensions. The implementation prioritizes efficiency through features like flash attention support while maintaining flexibility for different use cases.