Minor fixes
georgeyiasemis committed Jan 3, 2025
1 parent bc7605e commit 405e689
Showing 2 changed files with 23 additions and 34 deletions.
55 changes: 22 additions & 33 deletions direct/nn/transformers/vit.py
@@ -117,8 +117,6 @@ class GPSA(nn.Module):
Parameters
----------
- dimensionality : VisionTransformerDimensionality
- The dimensionality of the input data.
dim : int
Dimensionality of the input embeddings.
num_heads : int
@@ -255,7 +253,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class GPSA2D(GPSA):
- """Gated Positional Self-Attention module for Vision Transformer.
+ """Gated Positional Self-Attention module for Vision Transformer (2D variant).
Parameters
----------
@@ -291,7 +289,7 @@ def __init__(
use_local_init: bool = True,
grid_size=None,
) -> None:
- """Inits :class:`GPSA`.
+ """Inits :class:`GPSA2D`.
Parameters
----------
@@ -327,47 +325,40 @@ def __init__(
)

def local_init(self, locality_strength: Optional[float] = 1.0) -> None:
- """Initializes the parameters for a locally connected attention mechanism.
+ """Initializes the positional projection weights with locality distance.
Parameters
----------
locality_strength : float, optional
- A scalar multiplier for the locality distance. Default: 1.0.
- Returns
- -------
- None
+ Determines how focused the attention is around its center.
"""
self.v.weight.data.copy_(torch.eye(self.dim))
- locality_distance = 1 # max(1,1/locality_strength**.5)

kernel_size = int(self.num_heads**0.5)
center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2

# compute the positional projection weights with locality distance
for h1 in range(kernel_size):
for h2 in range(kernel_size):
position = h1 + kernel_size * h2
self.pos_proj.weight.data[position, 2] = -1
- self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance
- self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance
+ self.pos_proj.weight.data[position, 1] = 2 * (h1 - center)
+ self.pos_proj.weight.data[position, 0] = 2 * (h2 - center)

self.pos_proj.weight.data *= locality_strength

def get_rel_indices(self) -> None:
- """Generates relative positional indices for each patch in the input.
- Returns
- -------
- None
- """
+ """Get relative indices for 2D grid of patches."""
H, W = self.current_grid_size
N = H * W

rel_indices = torch.zeros(1, N, N, 3)

indx = torch.arange(W).view(1, -1) - torch.arange(W).view(-1, 1)
indx = indx.repeat(H, H)
indy = torch.arange(H).view(1, -1) - torch.arange(H).view(-1, 1)
indy = indy.repeat_interleave(W, dim=0).repeat_interleave(W, dim=1)
indd = indx**2 + indy**2

rel_indices[:, :, :, 2] = indd.unsqueeze(0)
rel_indices[:, :, :, 1] = indy.unsqueeze(0)
rel_indices[:, :, :, 0] = indx.unsqueeze(0)
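
Aside: the hunk above builds a (1, N, N, 3) tensor of pairwise patch offsets. A minimal standalone sketch of the same computation on a toy 2x3 grid (illustrative only, not part of this commit):

```python
import torch

# Standalone sketch of what GPSA2D.get_rel_indices builds, for a toy 2x3 patch grid.
# Variable names mirror the hunk above; the grid size is illustrative only.
H, W = 2, 3
N = H * W

indx = torch.arange(W).view(1, -1) - torch.arange(W).view(-1, 1)  # column offsets, (W, W)
indx = indx.repeat(H, H)                                          # tile to (N, N)
indy = torch.arange(H).view(1, -1) - torch.arange(H).view(-1, 1)  # row offsets, (H, H)
indy = indy.repeat_interleave(W, dim=0).repeat_interleave(W, dim=1)
indd = indx**2 + indy**2                                          # squared patch distance

rel_indices = torch.zeros(1, N, N, 3)
rel_indices[:, :, :, 2] = indd
rel_indices[:, :, :, 1] = indy
rel_indices[:, :, :, 0] = indx

# Entry (i, j) holds (column offset, row offset, squared distance) between patches i and j.
# Patch 0 sits at (row 0, col 0) and patch 5 at (row 1, col 2) in the 2x3 grid:
print(rel_indices[0, 0, 5])  # tensor([2., 1., 5.])
```

In ConViT-style gated positional self-attention these offsets are what the `pos_proj` weights set up in `local_init` are applied to when forming the positional attention scores.
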
@@ -425,8 +416,14 @@ def __init__(
)

def local_init(self, locality_strength: Optional[float] = 1.0) -> None:
+ """Initializes the positional projection weights with locality distance.
+ Parameters
+ ----------
+ locality_strength : float, optional
+ Determines how focused the attention is around its center.
+ """
self.v.weight.data.copy_(torch.eye(self.dim))
- locality_distance = 1

kernel_size = int(self.num_heads ** (1 / 3))
center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
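
The added description ("determines how focused the attention is around its center") boils down to a pre-softmax scaling at initialization: `locality_strength` rescales the positional scores, and larger values make the resulting softmax more peaked. A toy illustration with made-up scores, not repo code:

```python
import torch

# Toy illustration (not repo code): scaling positional scores before a softmax
# controls how peaked the resulting attention is around the closest patch.
scores = torch.tensor([0.0, -1.0, -4.0, -9.0])  # e.g. negated squared distances to a head's preferred patch
for strength in (0.5, 1.0, 3.0):
    print(strength, torch.softmax(strength * scores, dim=0))
# strength 0.5 spreads attention broadly; strength 3.0 puts nearly all mass on the nearest patch.
```
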
@@ -436,11 +433,13 @@ def local_init(self, locality_strength: Optional[float] = 1.0) -> None:
for h3 in range(kernel_size):
position = h1 + kernel_size * (h2 + kernel_size * h3)
self.pos_proj.weight.data[position, 2] = -1
- self.pos_proj.weight.data[position, 1] = 2 * (h2 - center) * locality_distance
- self.pos_proj.weight.data[position, 0] = 2 * (h3 - center) * locality_distance
+ self.pos_proj.weight.data[position, 1] = 2 * (h2 - center)
+ self.pos_proj.weight.data[position, 0] = 2 * (h3 - center)

self.pos_proj.weight.data *= locality_strength

def get_rel_indices(self) -> torch.Tensor:
+ """Get relative indices for 3D grid of patches."""
D, H, W = self.current_grid_size
N = D * H * W
rel_indices = torch.zeros(1, N, N, 3)
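
Both `local_init` variants (2D earlier, 3D here) write one row per head into `pos_proj`: a preferred offset in each spatial channel plus a fixed -1 on the squared-distance channel, all rescaled by `locality_strength`. A standalone sketch of the 2D pattern for nine heads (illustrative only, not part of this commit):

```python
import torch

# Sketch of the 2D local_init pattern for num_heads = 9 (a 3x3 grid of head centers).
# Mirrors the loop in the 2D hunk above; not repo code.
num_heads = 9
kernel_size = int(num_heads**0.5)                                              # 3
center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2  # 1

def pos_proj_rows(locality_strength: float = 1.0) -> torch.Tensor:
    w = torch.zeros(num_heads, 3)           # one (dx, dy, distance-gate) row per head
    for h1 in range(kernel_size):
        for h2 in range(kernel_size):
            position = h1 + kernel_size * h2
            w[position, 2] = -1              # weight on the squared-distance channel
            w[position, 1] = 2 * (h1 - center)
            w[position, 0] = 2 * (h2 - center)
    return w * locality_strength

print(pos_proj_rows())     # rows run from [-2., -2., -1.] to [2., 2., -1.] around [0., 0., -1.]
print(pos_proj_rows(2.0))  # same pattern, doubled: a stronger pull towards each head's center
```
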
@@ -480,10 +479,6 @@ class MHSA(nn.Module):
Dropout rate for the attention weights. Default is 0.
proj_drop : float
Dropout rate for the output of the module. Default is 0.
- grid_size : tuple[int, int] or None
- If not None, the module is designed to work with a grid of
- patches. grid_size is a tuple of the form (H, W) where H and W are the number of patches in
- the vertical and horizontal directions respectively. Default is None.
"""

def __init__(
@@ -494,7 +489,6 @@ def __init__(
qk_scale: float = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
- grid_size: tuple[int, int] = None,
) -> None:
"""Inits :class:`MHSA`.
@@ -513,10 +507,6 @@ def __init__(
Dropout rate for the attention weights. Default is 0.
proj_drop : float
Dropout rate for the output of the module. Default is 0.
- grid_size : tuple[int, int] or None
- If not None, the module is designed to work with a grid of
- patches. grid_size is a tuple of the form (H, W) where H and W are the number of patches in
- the vertical and horizontal directions respectively. Default is None.
"""
super().__init__()
self.num_heads = num_heads
@@ -528,7 +518,6 @@ def __init__(
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.apply(init_weights)
- self.current_grid_size = grid_size

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of :class:`MHSA`.
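
With `grid_size` removed, `MHSA` is plain multi-head self-attention over a token sequence. A hypothetical usage sketch: the import path follows the changed file's location, parameter names follow the code visible in the diff, and the tensor layout is an assumption, not documented repo behaviour.

```python
import torch

from direct.nn.transformers.vit import MHSA  # module path taken from the changed file above

# Hypothetical usage after this commit: no grid_size argument any more.
# dim, num_heads, attn_drop and proj_drop appear in the code shown in the diff; values are made up.
attn = MHSA(dim=64, num_heads=8, attn_drop=0.1, proj_drop=0.1)

x = torch.randn(2, 196, 64)  # assumed (batch, tokens, dim) layout
out = attn(x)                # forward(x) per the hunk above; expected to preserve the shape
print(out.shape)
```
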
2 changes: 1 addition & 1 deletion tests/tests_nn/test_transformers.py
@@ -5,7 +5,7 @@
import pytest
import torch

- from direct.nn.transformers.uformer import UFormerModel, AttentionTokenProjectionType, LeWinTransformerMLPTokenType
+ from direct.nn.transformers.uformer import AttentionTokenProjectionType, LeWinTransformerMLPTokenType, UFormerModel
from direct.nn.transformers.vit import VisionTransformer2D, VisionTransformer3D


