Commit

Merge branch 'dev' into 4980-get-wsi-at-mpp
ericspod authored Aug 7, 2024
2 parents 5461801 + 49a1e34 commit e8c1544
Showing 12 changed files with 286 additions and 71 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonapp.yml
@@ -99,6 +99,7 @@ jobs:
name: Install itk pre-release (Linux only)
run: |
python -m pip install --pre -U itk
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
- name: Install the dependencies
run: |
python -m pip install --user --upgrade pip wheel
21 changes: 16 additions & 5 deletions monai/bundle/scripts.py
@@ -217,10 +217,15 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:


def _download_from_ngc(
download_path: Path, filename: str, version: str, remove_prefix: str | None, progress: bool
download_path: Path,
filename: str,
version: str,
prefix: str = "monai_",
remove_prefix: str | None = "monai_",
progress: bool = True,
) -> None:
# ensure prefix is contained
filename = _add_ngc_prefix(filename)
filename = _add_ngc_prefix(filename, prefix=prefix)
url = _get_ngc_bundle_url(model_name=filename, version=version)
filepath = download_path / f"{filename}_v{version}.zip"
if remove_prefix:
@@ -231,10 +236,16 @@ def _download_from_ngc(


def _download_from_ngc_private(
download_path: Path, filename: str, version: str, remove_prefix: str | None, repo: str, headers: dict | None = None
download_path: Path,
filename: str,
version: str,
repo: str,
prefix: str = "monai_",
remove_prefix: str | None = "monai_",
headers: dict | None = None,
) -> None:
# ensure prefix is contained
filename = _add_ngc_prefix(filename)
filename = _add_ngc_prefix(filename, prefix=prefix)
request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
if has_requests:
headers = {} if headers is None else headers
@@ -491,7 +502,7 @@ def download(
url: url to download the data. If not `None`, data will be downloaded directly
and `source` will not be checked.
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
remove_prefix: This argument is used when `source` is "ngc" or "ngc_private". Currently, all ngc bundles
have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
maintain the consistency between these two sources, remove prefix is necessary.
Therefore, if specified, downloaded folder name will remove the prefix.
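The new `prefix` argument controls which prefix is ensured on the bundle name before the NGC URL is built, while `remove_prefix` still controls what is stripped from the downloaded folder name. A minimal, self-contained sketch of the intended round-trip, assuming the helpers simply prepend or strip the prefix (their bodies are not shown in this diff, so treat this as an illustration rather than the actual implementation):

# hypothetical stand-ins for the helpers in monai/bundle/scripts.py
def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str:
    # ensure the NGC-side bundle name carries the prefix
    return name if name.startswith(prefix) else f"{prefix}{name}"

def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:
    # strip the prefix so the local folder matches the model-zoo name
    return name[len(prefix):] if name.startswith(prefix) else name

assert _add_ngc_prefix("spleen_ct_segmentation") == "monai_spleen_ct_segmentation"
assert _remove_ngc_prefix("monai_spleen_ct_segmentation") == "spleen_ct_segmentation"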
73 changes: 53 additions & 20 deletions monai/networks/blocks/crossattention.py
@@ -17,7 +17,7 @@
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import
from monai.utils import optional_import, pytorch_after

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -44,6 +44,7 @@ def __init__(
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
"""
Args:
@@ -55,13 +56,16 @@ def __init__(
dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
causal (bool, optional): whether to use causal attention.
sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
"decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use Pytorch's inbuilt
flash attention for a memory efficient attention mechanism (see
https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

super().__init__()
@@ -81,6 +85,20 @@ def __init__(
if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
raise ValueError(
"use_flash_attention is only supported for PyTorch versions >= 2.0."
"Upgrade your PyTorch or set the flag to False."
)
if use_flash_attention and save_attn:
raise ValueError(
"save_attn has been set to True, but use_flash_attention is also set"
"to True. save_attn can only be used if use_flash_attention is False"
)

if use_flash_attention and rel_pos_embedding is not None:
raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.context_input_size = context_input_size if context_input_size else hidden_size
@@ -94,13 +112,15 @@ def __init__(
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.dropout_rate = dropout_rate

self.scale = self.head_dim**-0.5
self.save_attn = save_attn
self.attention_dtype = attention_dtype

self.causal = causal
self.sequence_length = sequence_length
self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
@@ -142,26 +162,39 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)
q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) #
k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
if self.use_flash_attention:
x = torch.nn.functional.scaled_dot_product_attention(
query=q.transpose(1, 2),
key=k.transpose(1, 2),
value=v.transpose(1, 2),
scale=self.scale,
dropout_p=self.dropout_rate,
is_causal=self.causal,
).transpose(
1, 2
) # Back to (b, nh, t, hs)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
# apply relative positional embedding if defined
if self.rel_positional_embedding is not None:
att_mat = self.rel_positional_embedding(x, att_mat, q)

if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)
att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()
if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
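The new flash-attention branch hands the computation to PyTorch's scaled_dot_product_attention instead of materialising the attention matrix with einsum and softmax. A small, self-contained sketch of how the two computations relate on tensors laid out as (batch, heads, tokens, head_dim), with dropout, causal masking and relative embeddings omitted; shapes are illustrative only:

import torch
import torch.nn.functional as F

b, nh, t, hs = 2, 4, 16, 8
q, k, v = (torch.randn(b, nh, t, hs) for _ in range(3))
scale = hs ** -0.5  # same scaling the block stores in self.scale

# explicit path: attention matrix, softmax, weighted sum of values
att = (torch.einsum("bhxd,bhyd->bhxy", q, k) * scale).softmax(dim=-1)
out_manual = torch.einsum("bhxy,bhyd->bhxd", att, v)

# fused path: the default scale is 1/sqrt(head_dim), matching the explicit path
out_fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

print(torch.allclose(out_manual, out_fused, atol=1e-5))  # expected: True (up to kernel tolerance)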
20 changes: 16 additions & 4 deletions monai/networks/blocks/mlp.py
@@ -11,12 +11,15 @@

from __future__ import annotations

from typing import Union

import torch.nn as nn

from monai.networks.layers import get_act_layer
from monai.networks.layers.factories import split_args
from monai.utils import look_up_option

SUPPORTED_DROPOUT_MODE = {"vit", "swin"}
SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"}


class MLPBlock(nn.Module):
@@ -39,7 +42,7 @@ def __init__(
https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
"swin" corresponds to one instance as implemented in
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23
"vista3d" mode does not use dropout.
"""

@@ -48,15 +51,24 @@ def __init__(
if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
mlp_dim = mlp_dim or hidden_size
self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
act_name, _ = split_args(act)
self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
self.linear2 = nn.Linear(mlp_dim, hidden_size)
self.fn = get_act_layer(act)
self.drop1 = nn.Dropout(dropout_rate)
# Use Union[nn.Dropout, nn.Identity] for type annotations
self.drop1: Union[nn.Dropout, nn.Identity]
self.drop2: Union[nn.Dropout, nn.Identity]

dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
if dropout_opt == "vit":
self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = nn.Dropout(dropout_rate)
elif dropout_opt == "swin":
self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = self.drop1
elif dropout_opt == "vista3d":
self.drop1 = nn.Identity()
self.drop2 = nn.Identity()
else:
raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")

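The new "vista3d" dropout mode installs nn.Identity for both dropout slots, so the MLP applies no dropout regardless of dropout_rate, while "swin" keeps sharing a single nn.Dropout instance between the two positions. A short usage sketch with the class as changed above (shapes and rates are illustrative):

import torch
from monai.networks.blocks.mlp import MLPBlock

x = torch.randn(2, 16, 64)  # (batch, tokens, hidden_size)

mlp_vit = MLPBlock(hidden_size=64, mlp_dim=128, dropout_rate=0.1, dropout_mode="vit")
mlp_vista = MLPBlock(hidden_size=64, mlp_dim=128, dropout_rate=0.1, dropout_mode="vista3d")

print(mlp_vit(x).shape)                # torch.Size([2, 16, 64])
print(type(mlp_vista.drop1).__name__)  # Identity: dropout_rate has no effect in this mode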
58 changes: 45 additions & 13 deletions monai/networks/blocks/selfattention.py
@@ -15,9 +15,10 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import
from monai.utils import optional_import, pytorch_after

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -42,6 +43,7 @@ def __init__(
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
"""
Args:
@@ -59,6 +61,9 @@
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use Pytorch's inbuilt
flash attention for a memory efficient attention mechanism (see
https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

@@ -82,6 +87,20 @@ def __init__(
if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
raise ValueError(
"use_flash_attention is only supported for PyTorch versions >= 2.0."
"Upgrade your PyTorch or set the flag to False."
)
if use_flash_attention and save_attn:
raise ValueError(
"save_attn has been set to True, but use_flash_attention is also set"
"to True. save_attn can only be used if use_flash_attention is False."
)

if use_flash_attention and rel_pos_embedding is not None:
raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
@@ -91,12 +110,14 @@ def __init__(
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.dropout_rate = dropout_rate
self.scale = self.dim_head**-0.5
self.save_attn = save_attn
self.att_mat = torch.Tensor()
self.attention_dtype = attention_dtype
self.causal = causal
self.sequence_length = sequence_length
self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
@@ -130,23 +151,34 @@ def forward(self, x):
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
if self.use_flash_attention:
x = F.scaled_dot_product_attention(
query=q.transpose(1, 2),
key=k.transpose(1, 2),
value=v.transpose(1, 2),
scale=self.scale,
dropout_p=self.dropout_rate,
is_causal=self.causal,
).transpose(1, 2)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
# apply relative positional embedding if defined
if self.rel_positional_embedding is not None:
att_mat = self.rel_positional_embedding(x, att_mat, q)

if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)
att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()
if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
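Besides adding the fused path, the constructor now rejects combinations that cannot work with it: the fused kernel never materialises the attention matrix (so save_attn is unavailable) and relative positional embeddings are only applied in the explicit path. A hedged usage sketch, assuming einops is installed and PyTorch is recent enough to provide scaled_dot_product_attention:

from monai.networks.blocks.selfattention import SABlock

# valid: fused attention, no saved attention matrix, no relative embedding
block = SABlock(hidden_size=128, num_heads=4, use_flash_attention=True)

# raises ValueError: save_attn needs the explicit attention matrix
try:
    SABlock(hidden_size=128, num_heads=4, use_flash_attention=True, save_attn=True)
except ValueError as err:
    print(err)

# raises ValueError: relative positional embeddings require the explicit path
try:
    SABlock(hidden_size=128, num_heads=4, use_flash_attention=True, rel_pos_embedding="decomposed")
except ValueError as err:
    print(err)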
8 changes: 7 additions & 1 deletion monai/networks/blocks/spatialattention.py
@@ -33,6 +33,7 @@ class SpatialAttentionBlock(nn.Module):
num_channels: number of input channels. Must be divisible by num_head_channels.
num_head_channels: number of channels per head.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

@@ -44,6 +45,7 @@ def __init__(
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
super().__init__()

@@ -54,7 +56,11 @@ def __init__(
raise ValueError("num_channels must be divisible by num_head_channels")
num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
self.attn = SABlock(
hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype
hidden_size=num_channels,
num_heads=num_heads,
qkv_bias=True,
attention_dtype=attention_dtype,
use_flash_attention=use_flash_attention,
)

def forward(self, x: torch.Tensor):
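SpatialAttentionBlock itself does not implement attention; it normalises the input and delegates to an inner SABlock, so the new use_flash_attention flag is simply forwarded. A hedged construction sketch (channel counts illustrative; assumes einops and a recent PyTorch as above):

from monai.networks.blocks.spatialattention import SpatialAttentionBlock

block = SpatialAttentionBlock(
    spatial_dims=2,
    num_channels=32,       # must be divisible by num_head_channels
    num_head_channels=8,   # -> 4 heads in the inner SABlock
    use_flash_attention=True,
)
print(type(block.attn).__name__)       # SABlock
print(block.attn.use_flash_attention)  # True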
(The remaining changed files are not shown here.)