diff --git a/src/models/pro_mod.py b/src/models/pro_mod.py
index 404a224f..148cc25d 100644
--- a/src/models/pro_mod.py
+++ b/src/models/pro_mod.py
@@ -351,21 +351,21 @@ def forward_pro(self, data):
                           self.edge_weight != 'binary') else None
         target_x = self.relu(target_x)
 
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
-                                    training=self.training)
+        # ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
+        #                             training=self.training)
         # conv1
-        xt = self.pro_conv1(target_x, ei_drp, ew)
+        xt = self.pro_conv1(target_x, ei, ew)
         xt = self.relu(xt)
 
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
-                                    training=self.training)
+        # ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
+        #                             training=self.training)
         # conv2
-        xt = self.pro_conv2(xt, ei_drp, ew)
+        xt = self.pro_conv2(xt, ei, ew)
         xt = self.relu(xt)
 
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
-                                    training=self.training)
+        # ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
+        #                             training=self.training)
         # conv3
-        xt = self.pro_conv3(xt, ei_drp, ew)
+        xt = self.pro_conv3(xt, ei, ew)
         xt = self.relu(xt)
         # flatten/pool