@@ -38,20 +38,20 @@ def __init__(
38
38
39
39
# First GATv2Conv layer
40
40
self .conv_first = GATv2Conv ((- 1 , - 1 ), hidden_channels , heads = heads , add_self_loops = False )
41
- self .lin_first = Linear (- 1 , hidden_channels * heads )
41
+ # self.lin_first = Linear(-1, hidden_channels * heads)
42
42
43
43
# Middle GATv2Conv layers
44
44
self .num_mid_layers = num_mid_layers
45
45
if num_mid_layers > 0 :
46
46
self .conv_mid_layers = torch .nn .ModuleList ()
47
- self .lin_mid_layers = torch .nn .ModuleList ()
47
+ # self.lin_mid_layers = torch.nn.ModuleList()
48
48
for _ in range (num_mid_layers ):
49
49
self .conv_mid_layers .append (GATv2Conv ((- 1 , - 1 ), hidden_channels , heads = heads , add_self_loops = False ))
50
- self .lin_mid_layers .append (Linear (- 1 , hidden_channels * heads ))
50
+ # self.lin_mid_layers.append(Linear(-1, hidden_channels * heads))
51
51
52
52
# Last GATv2Conv layer
53
53
self .conv_last = GATv2Conv ((- 1 , - 1 ), out_channels , heads = heads , add_self_loops = False )
54
- self .lin_last = Linear (- 1 , out_channels * heads )
54
+ # self.lin_last = Linear(-1, out_channels * heads)
55
55
56
56
def forward (self , x : Tensor , edge_index : Tensor ) -> Tensor :
57
57
"""
@@ -70,19 +70,19 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
70
70
x = self .tx_embedding (((x .sum (1 ) * is_one_dim ).int ())) * is_one_dim + self .lin0 (x .float ()) * (1 - is_one_dim )
71
71
# First layer
72
72
x = x .relu ()
73
- x = self .conv_first (x , edge_index ) + self .lin_first (x )
73
+ x = self .conv_first (x , edge_index ) # + self.lin_first(x)
74
74
x = x .relu ()
75
75
76
76
# Middle layers
77
77
if self .num_mid_layers > 0 :
78
78
for i in range (self .num_mid_layers ):
79
79
conv_mid = self .conv_mid_layers [i ]
80
- lin_mid = self .lin_mid_layers [i ]
81
- x = conv_mid (x , edge_index ) + lin_mid (x )
80
+ # lin_mid = self.lin_mid_layers[i]
81
+ x = conv_mid (x , edge_index ) # + lin_mid(x)
82
82
x = x .relu ()
83
83
84
84
# Last layer
85
- x = self .conv_last (x , edge_index ) + self .lin_last (x )
85
+ x = self .conv_last (x , edge_index ) # + self.lin_last(x)
86
86
87
87
return x
88
88
0 commit comments