Get back to SimpleDecoder with an attention layer
ThomasRieutord committed Sep 5, 2023
1 parent 2a6e2e8 commit de335ca
Showing 1 changed file with 10 additions and 6 deletions.
mmt/graphs/models/transformer_embedding.py
@@ -167,9 +167,9 @@ def __init__(self, in_channels, out_channels, resize = 1):
         self.attn3 = CrossChannelAttention(out_channels, out_channels)

     def forward(self, x):
-        # x = self.attn1(x)
+        x = self.attn1(x)
         x = self.attn2(x)
-        # x = self.attn3(x)
+        x = self.attn3(x)
         return x

 class SimpleDecoder(nn.Module):
@@ -205,13 +205,17 @@ def __init__(
         self.decoder.add_module("pool", nn.MaxPool2d(resize))
         for i in range(1, depth):
             self.decoder.add_module(
-                "conv_{}".format(i),
+                f"conv_{i}",
                 nn.Conv2d(inc, nf, kernel_size=3, padding=1, bias=bias),
             )
             self.decoder.add_module(
-                "groupnorm_{}".format(i), nn.GroupNorm(num_groups, nf)
+                f"groupnorm_{i}", nn.GroupNorm(num_groups, nf)
             )
+            self.decoder.add_module(f"relu_{i}", nn.ReLU(inplace=True))
+            self.decoder.add_module(
+                f"attn_{i}",
+                CrossChannelAttention(nf)
+            )
-            self.decoder.add_module("relu_{}".format(i), nn.ReLU(inplace=True))
             inc = nf
         self.decoder.add_module("conv_{}".format(depth), nn.Conv2d(inc, n_classes, 1))

@@ -485,7 +489,7 @@ def __init__(
         resize=None,
         pooling_factors=[3, 3, 3, 3, 3],
         encoder=AttentionUNet,
-        decoder=AttentionDecoder,
+        decoder=SimpleDecoder,
         decoder_atrou=True,
         tlm_p=0,
         bias=False,
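For context, a minimal, self-contained sketch of what the modified SimpleDecoder stack amounts to after this commit. The CrossChannelAttention stub, the build_decoder helper, and the channel and class counts below are illustrative assumptions for the sketch, not the repository's actual code; the pooling step guarded by resize in the real __init__ is also left out to keep it short.

import torch
import torch.nn as nn


class CrossChannelAttention(nn.Module):
    # Placeholder assumption: the real module is defined elsewhere in
    # transformer_embedding.py; a 1x1-conv sigmoid gate stands in for it here
    # so the sketch runs on its own.
    def __init__(self, channels, out_channels=None):
        super().__init__()
        out_channels = out_channels or channels
        self.gate = nn.Sequential(
            nn.Conv2d(channels, out_channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.gate(x)


def build_decoder(inc, nf, n_classes, depth=3, num_groups=4, bias=False):
    # Mirrors the loop added in this commit: conv -> groupnorm -> relu ->
    # cross-channel attention, repeated depth-1 times, followed by a 1x1 conv
    # that maps to the class scores.
    decoder = nn.Sequential()
    for i in range(1, depth):
        decoder.add_module(
            f"conv_{i}", nn.Conv2d(inc, nf, kernel_size=3, padding=1, bias=bias)
        )
        decoder.add_module(f"groupnorm_{i}", nn.GroupNorm(num_groups, nf))
        decoder.add_module(f"relu_{i}", nn.ReLU(inplace=True))
        decoder.add_module(f"attn_{i}", CrossChannelAttention(nf))
        inc = nf
    decoder.add_module(f"conv_{depth}", nn.Conv2d(inc, n_classes, 1))
    return decoder


# Hypothetical sizes: a 32-channel embedding mapped to 12 output classes.
decoder = build_decoder(inc=32, nf=64, n_classes=12)
print(decoder(torch.randn(1, 32, 60, 60)).shape)  # torch.Size([1, 12, 60, 60])

The 3x3 convolutions use padding=1, so the attention-augmented decoder keeps the spatial resolution of its input and only changes the channel dimension.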
