Commit
Add MRI transformers
georgeyiasemis committed Jan 3, 2025
1 parent 405e689 commit 7a27ee8
Showing 2 changed files with 972 additions and 19 deletions.
88 changes: 69 additions & 19 deletions direct/nn/transformers/config.py
@@ -3,18 +3,23 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional
 
 from omegaconf import MISSING
 
 from direct.config.defaults import ModelConfig
+from direct.constants import COMPLEX_SIZE
 from direct.nn.transformers.uformer import AttentionTokenProjectionType, LeWinTransformerMLPTokenType
 
 
 @dataclass
 class UFormerModelConfig(ModelConfig):
+    in_channels: int = COMPLEX_SIZE
+    out_channels: Optional[int] = None
     patch_size: int = 256
     embedding_dim: int = 32
-    encoder_depths: Tuple[int, ...] = (2, 2, 2, 2)
-    encoder_num_heads: Tuple[int, ...] = (1, 2, 4, 8)
+    encoder_depths: tuple[int, ...] = (2, 2, 2, 2)
+    encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8)
     bottleneck_depth: int = 2
     bottleneck_num_heads: int = 16
     win_size: int = 8
@@ -34,31 +39,34 @@ class UFormerModelConfig(ModelConfig):
 
 
 @dataclass
-class VisionTransformer2DConfig(ModelConfig):
-    average_img_size: int | tuple[int, int] = MISSING
-    patch_size: int | tuple[int, int] = 16
-    embedding_dim: int = 64
-    depth: int = 8
-    num_heads: int = (9,)
+class ImageDomainMRIUFormerConfig(ModelConfig):
+    patch_size: int = 256
+    embedding_dim: int = 32
+    encoder_depths: tuple[int, ...] = (2, 2, 2, 2)
+    encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8)
+    bottleneck_depth: int = 2
+    bottleneck_num_heads: int = 16
+    win_size: int = 8
     mlp_ratio: float = 4.0
-    qkv_bias: bool = False
-    qk_scale: float = None
+    qkv_bias: bool = True
+    qk_scale: Optional[float] = None
     drop_rate: float = 0.0
     attn_drop_rate: float = 0.0
-    dropout_path_rate: float = 0.0
-    use_gpsa: bool = True
-    locality_strength: float = 1.0
-    use_pos_embedding: bool = True
+    drop_path_rate: float = 0.1
+    patch_norm: bool = True
+    token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR
+    token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.LEFF
+    shift_flag: bool = True
+    modulator: bool = False
+    cross_modulator: bool = False
     normalized: bool = True
 
 
 @dataclass
-class VisionTransformer3DConfig(ModelConfig):
-    average_img_size: int | tuple[int, int, int] = MISSING
-    patch_size: int | tuple[int, int, int] = 16
+class MRIViTConfig(ModelConfig):
     embedding_dim: int = 64
     depth: int = 8
-    num_heads: int = (9,)
+    num_heads: int = 9
     mlp_ratio: float = 4.0
     qkv_bias: bool = False
     qk_scale: float = None
@@ -69,3 +77,45 @@ class VisionTransformer3DConfig(ModelConfig):
     locality_strength: float = 1.0
     use_pos_embedding: bool = True
     normalized: bool = True
+
+
+@dataclass
+class VisionTransformer2DConfig(MRIViTConfig):
+    in_channels: int = COMPLEX_SIZE
+    out_channels: Optional[int] = None
+    average_img_size: tuple[int, int] = MISSING
+    patch_size: tuple[int, int] = (16, 16)
+
+
+@dataclass
+class VisionTransformer3DConfig(MRIViTConfig):
+    in_channels: int = COMPLEX_SIZE
+    out_channels: Optional[int] = None
+    average_img_size: tuple[int, int, int] = MISSING
+    patch_size: tuple[int, int, int] = (16, 16, 16)
+
+
+@dataclass
+class ImageDomainMRIViT2DConfig(MRIViTConfig):
+    average_size: tuple[int, int] = (320, 320)
+    patch_size: tuple[int, int] = (16, 16)
+
+
+@dataclass
+class ImageDomainMRIViT3DConfig(MRIViTConfig):
+    average_size: tuple[int, int, int] = (320, 320, 320)
+    patch_size: tuple[int, int, int] = (16, 16, 16)
+
+
+@dataclass
+class KSpaceDomainMRIViT2DConfig(MRIViTConfig):
+    average_size: tuple[int, int] = (320, 320)
+    patch_size: tuple[int, int] = (16, 16)
+    compute_per_coil: bool = True
+
+
+@dataclass
+class KSpaceDomainMRIViT3DConfig(MRIViTConfig):
+    average_size: tuple[int, int, int] = (320, 320, 320)
+    patch_size: tuple[int, int, int] = (16, 16, 16)
+    compute_per_coil: bool = True
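
For readers who want to exercise the new UFormer config, here is a minimal usage sketch. It assumes the DIRECT package at this revision is importable; the field names and defaults are taken directly from the diff above, and it does not show how DIRECT's own training entrypoint consumes these configs.

# Minimal sketch, assuming this commit of DIRECT is installed.
from direct.nn.transformers.config import ImageDomainMRIUFormerConfig

# Defaults come straight from the field declarations in the diff.
cfg = ImageDomainMRIUFormerConfig()
assert cfg.encoder_num_heads == (1, 2, 4, 8)
assert cfg.qkv_bias is True

# Individual fields can be overridden at construction time,
# e.g. a smaller model for quick experiments.
small_cfg = ImageDomainMRIUFormerConfig(patch_size=128, embedding_dim=16)
print(small_cfg.patch_size)  # 128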
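
The second half of the diff restructures the ViT configs around a shared MRIViTConfig base: the 2D/3D and image-/k-space-domain variants only add size, patch, and per-coil fields. The sketch below, again assuming this revision is importable, introspects the dataclasses to make that split visible; it states nothing beyond what the class definitions above declare.

# List which fields each new ViT config adds on top of the shared MRIViTConfig base.
from dataclasses import fields

from direct.nn.transformers.config import (
    MRIViTConfig,
    ImageDomainMRIViT2DConfig,
    ImageDomainMRIViT3DConfig,
    KSpaceDomainMRIViT2DConfig,
    KSpaceDomainMRIViT3DConfig,
)

base_fields = {f.name for f in fields(MRIViTConfig)}

for config_cls in (
    ImageDomainMRIViT2DConfig,
    ImageDomainMRIViT3DConfig,
    KSpaceDomainMRIViT2DConfig,
    KSpaceDomainMRIViT3DConfig,
):
    extra = [f.name for f in fields(config_cls) if f.name not in base_fields]
    print(f"{config_cls.__name__}: {extra}")
# Expected: the 2D/3D variants add average_size and patch_size, and the
# k-space variants additionally add compute_per_coil.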