@@ -35,40 +35,3 @@ def forward(self, x):
35
35
out = self .conv3 (out )
36
36
return out
37
37
38
-
39
- class encoder2 (nn .Module ):
40
- def __init__ (self , vgg ):
41
- super (encoder2 , self ).__init__ ()
42
- self .conv1 = nn .Conv2d (3 , 3 , 1 , 1 , 0 )
43
- self .conv1 .weight = torch .nn .Parameter (vgg .get (0 ).weight .float ())
44
- self .conv1 .bias = torch .nn .Parameter (vgg .get (0 ).bias .float ())
45
- self .rp1 = nn .ReflectionPad2d ((1 , 1 , 1 , 1 ))
46
- self .conv2 = nn .Conv2d (3 , 64 , 3 , 1 , 0 )
47
- self .conv2 .weight = torch .nn .Parameter (vgg .get (2 ).weight .float ())
48
- self .conv2 .bias = torch .nn .Parameter (vgg .get (2 ).bias .float ())
49
- self .relu2 = nn .ReLU (inplace = True )
50
- self .rp3 = nn .ReflectionPad2d ((1 , 1 , 1 , 1 ))
51
- self .conv3 = nn .Conv2d (64 , 64 , 3 , 1 , 0 )
52
- self .conv3 .weight = torch .nn .Parameter (vgg .get (5 ).weight .float ())
53
- self .conv3 .bias = torch .nn .Parameter (vgg .get (5 ).bias .float ())
54
- self .relu3 = nn .ReLU (inplace = True )
55
- self .mp = nn .MaxPool2d (kernel_size = 2 , stride = 2 , return_indices = True )
56
- self .rp4 = nn .ReflectionPad2d ((1 , 1 , 1 , 1 ))
57
- self .conv4 = nn .Conv2d (64 , 128 , 3 , 1 , 0 )
58
- self .conv4 .weight = torch .nn .Parameter (vgg .get (9 ).weight .float ())
59
- self .conv4 .bias = torch .nn .Parameter (vgg .get (9 ).bias .float ())
60
- self .relu4 = nn .ReLU (inplace = True )
61
-
62
- def forward (self , x ):
63
- out = self .conv1 (x )
64
- out = self .rp1 (out )
65
- out = self .conv2 (out )
66
- out = self .relu2 (out )
67
- out = self .rp3 (out )
68
- out = self .conv3 (out )
69
- pool = self .relu3 (out )
70
- out , pool_idx = self .mp (pool )
71
- out = self .rp4 (out )
72
- out = self .conv4 (out )
73
- out = self .relu4 (out )
74
- return out
0 commit comments