diff --git a/mmt/graphs/models/transformer_embedding.py b/mmt/graphs/models/transformer_embedding.py
index 9fa6139..b342e35 100644
--- a/mmt/graphs/models/transformer_embedding.py
+++ b/mmt/graphs/models/transformer_embedding.py
@@ -167,9 +167,9 @@ def __init__(self, in_channels, out_channels, resize = 1):
         self.attn3 = CrossChannelAttention(out_channels, out_channels)
 
     def forward(self, x):
-        # x = self.attn1(x)
+        x = self.attn1(x)
         x = self.attn2(x)
-        # x = self.attn3(x)
+        x = self.attn3(x)
         return x
 
 class SimpleDecoder(nn.Module):
@@ -205,13 +205,17 @@ def __init__(
         self.decoder.add_module("pool", nn.MaxPool2d(resize))
         for i in range(1, depth):
             self.decoder.add_module(
-                "conv_{}".format(i),
+                f"conv_{i}",
                 nn.Conv2d(inc, nf, kernel_size=3, padding=1, bias=bias),
             )
             self.decoder.add_module(
-                "groupnorm_{}".format(i), nn.GroupNorm(num_groups, nf)
+                f"groupnorm_{i}", nn.GroupNorm(num_groups, nf)
+            )
+            self.decoder.add_module(f"relu_{i}", nn.ReLU(inplace=True))
+            self.decoder.add_module(
+                f"attn_{i}",
+                CrossChannelAttention(nf)
             )
-            self.decoder.add_module("relu_{}".format(i), nn.ReLU(inplace=True))
             inc = nf
         self.decoder.add_module("conv_{}".format(depth), nn.Conv2d(inc, n_classes, 1))
 
@@ -485,7 +489,7 @@ def __init__(
         resize=None,
         pooling_factors=[3, 3, 3, 3, 3],
         encoder=AttentionUNet,
-        decoder=AttentionDecoder,
+        decoder=SimpleDecoder,
         decoder_atrou=True,
         tlm_p=0,
         bias=False,
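
For reference, CrossChannelAttention itself is not touched by this diff, so its exact signature is not visible here. The sketch below is a hypothetical squeeze-and-excitation-style stand-in, written only to illustrate an interface compatible with both call sites seen above (CrossChannelAttention(out_channels, out_channels) in the attention block and CrossChannelAttention(nf) in SimpleDecoder); the optional out_channels argument and the reduction parameter are assumptions, not the actual implementation.

import torch
import torch.nn as nn


class CrossChannelAttention(nn.Module):
    """Hypothetical channel-attention stand-in; the real module lives in transformer_embedding.py."""

    def __init__(self, in_channels, out_channels=None, reduction=4):
        super().__init__()
        # Assumed default: a single-argument call reuses the input channel count.
        out_channels = out_channels or in_channels
        hidden = max(in_channels // reduction, 1)
        # Global average pooling + a small MLP produce per-channel gating weights.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=1),
            nn.Sigmoid(),
        )
        # 1x1 projection keeps the element-wise product well defined when
        # in_channels != out_channels.
        self.proj = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        weights = self.mlp(self.pool(x))  # (N, out_channels, 1, 1)
        return self.proj(x) * weights     # re-weight feature-map channels


if __name__ == "__main__":
    # Both call patterns from the diff work against this sketch.
    attn_pair = CrossChannelAttention(64, 64)   # as in the attention block
    attn_single = CrossChannelAttention(32)     # as added to SimpleDecoder
    x = torch.randn(2, 32, 16, 16)
    print(attn_single(x).shape)                 # torch.Size([2, 32, 16, 16])

With an interface like this, the new per-stage attn_{i} modules slot into the SimpleDecoder nn.Sequential between the ReLU of one conv block and the conv of the next without changing tensor shapes.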