Add experimental attention modules (oom)
ThomasRieutord committed Sep 7, 2023
1 parent 58dcb29 commit 6e38b2b
83 changes: 77 additions & 6 deletions mmt/graphs/models/transformer_embedding.py
@@ -155,6 +155,24 @@ def check_shapes(self, x = None):
        attention = nn.functional.scaled_dot_product_attention(q, k, v).permute(0,3,2,1)
        print(f"attention = {attention.shape}")


class AttentionEncoder(nn.Module):
    """Encoder chaining cross-channel and cross-resolution attention layers."""

    def __init__(self, in_channels, out_channels, resize = 1):
        super().__init__()
        print(f"Init {self.__class__.__name__} with in_channels={in_channels}, out_channels={out_channels}, resize = {resize}")
        if resize is None:
            resize = 1

        self.attn1 = CrossChannelAttention(in_channels, in_channels)
        self.attn2 = CrossResolutionAttention(in_channels=in_channels, out_channels=out_channels, resize=resize)
        self.attn3 = CrossChannelAttention(out_channels, out_channels)

    def forward(self, x):
        x = self.attn1(x)
        x = self.attn2(x)
        x = self.attn3(x)
        return x

class AttentionDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, resize = 1):
        super().__init__()
@@ -212,17 +230,54 @@ def __init__(
f"groupnorm_{i}", nn.GroupNorm(num_groups, nf)
)
self.decoder.add_module(f"relu_{i}", nn.ReLU(inplace=True))
self.decoder.add_module(
f"attn_{i}",
CrossChannelAttention(nf)
)
inc = nf
self.decoder.add_module("conv_{}".format(depth), nn.Conv2d(inc, n_classes, 1))
self.decoder.add_module(f"attn_{depth}f", CrossChannelAttention(inc, 10))
self.decoder.add_module(f"conv_{depth}", nn.Conv2d(inc, n_classes, 1))

def forward(self, x):
x = self.decoder(x)
return x

class SimpleAttentionDecoder(nn.Module):
    """
    Decoder head mapping feature maps to class scores: an optional pooling
    step (atrous or max), a cross-channel attention layer (currently disabled
    in the forward pass), and a final 1x1 convolution.
    """

    def __init__(
        self,
        in_features,
        n_classes,
        depth=1,
        num_groups=4,
        nf=52,
        resize=None,
        atrou=True,
        bias=False,
    ):
        super().__init__()
        self.resize = resize

        if num_groups is None:
            num_groups = nf

        if resize is not None:
            if atrou:
                self.pool = AtrouMMU(in_features, scale_factor=resize, bias=bias)
            else:
                self.pool = nn.MaxPool2d(resize)

        self.attn1 = CrossChannelAttention(in_features, 10)
        self.conv1 = nn.Conv2d(in_features, n_classes, 1)

    def forward(self, x):
        if self.resize is not None:
            x = self.pool(x)

        # x = self.attn1(x)
        x = self.conv1(x)
        return x


class AttentionUNet(nn.Module):
    def __init__(
@@ -295,6 +350,22 @@ def __init__(
            ),
            CrossChannelAttention(number_feature_map, 10)
        )
        # self.down3 = Down(
        #     number_feature_map,
        #     number_feature_map,
        #     mode=down_mode,
        #     num_groups=num_groups,
        #     factor=pooling_factors[2],
        #     bias=bias,
        # )
        # self.down4 = Down(
        #     number_feature_map,
        #     number_feature_map,
        #     mode=down_mode,
        #     num_groups=num_groups,
        #     factor=pooling_factors[3],
        #     bias=bias,
        # )
        self.down5 = Down(
            number_feature_map,
            number_feature_map,
@@ -489,7 +560,7 @@ def __init__(
        resize=None,
        pooling_factors=[3, 3, 3, 3, 3],
        encoder=AttentionUNet,
        decoder=SimpleDecoder,
        decoder=SimpleAttentionDecoder,
        decoder_atrou=True,
        tlm_p=0,
        bias=False,
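A minimal usage sketch for the modules added in this commit (not part of the diff itself): it wires AttentionEncoder into SimpleAttentionDecoder using the constructor signatures visible above. The import path follows the changed file; the tensor sizes, channel counts, and output shape are illustrative assumptions, not values taken from the repository.

# Hedged sketch: exercises the new experimental attention modules end to end.
import torch
from mmt.graphs.models.transformer_embedding import (
    AttentionEncoder,
    SimpleAttentionDecoder,
)

# Hypothetical input: batch of 2 maps with 12 channels on a 60x60 grid.
x = torch.randn(2, 12, 60, 60)

encoder = AttentionEncoder(in_channels=12, out_channels=52, resize=1)
decoder = SimpleAttentionDecoder(in_features=52, n_classes=10)

emb = encoder(x)        # cross-channel -> cross-resolution -> cross-channel attention
logits = decoder(emb)   # pooling skipped (resize=None), then 1x1 conv to class scores
print(logits.shape)     # expected (2, 10, H, W) if the attention blocks preserve spatial size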
