AttributeError: 'GCNConv' object has no attribute '__check_input__' #206
Comments
I have faced the same problem. Have you solved it?
No. Currently I'm running explainable models in PyTorch Geometric. I would like to test this library, but I cannot run the examples.
I have run into the same problem. Any updates?
Hello again, is there any news on this issue?
I experienced the same error. It is caused by a naming convention change in torch_geometric v2.3.0, where […]. To fix it, update the attribute names in your code. For instance, in `~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/dig/xgraph/models/models.py:362`, change […]. Alternatively, downgrade torch_geometric to […].
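A minimal sketch of the "update the attribute names" fix that keeps DIG working on both sides of the change. It assumes (this is not stated in the thread) that v2.3.0 maps each dunder-style internal such as `__check_input__`, the name in this issue's title, to a single-underscore name such as `_check_input`. The helpers `new_style_name` and `compat_getattr` are hypothetical and are not part of DIG or torch_geometric.

```python
from typing import Any


def new_style_name(old_name: str) -> str:
    """Map a dunder-style name like '__check_input__' to '_check_input'.

    ASSUMPTION: this mirrors the torch_geometric 2.3.0 renaming; check the
    installed torch_geometric/nn/conv/message_passing.py before relying on it.
    """
    if len(old_name) > 4 and old_name.startswith("__") and old_name.endswith("__"):
        return "_" + old_name[2:-2]
    return old_name


def compat_getattr(obj: Any, old_name: str) -> Any:
    """Return obj's attribute under the new name if present, else the old one."""
    for name in (new_style_name(old_name), old_name):
        try:
            return getattr(obj, name)
        except AttributeError:
            continue  # not defined under this name in the installed version
    raise AttributeError(
        f"'{type(obj).__name__}' object has neither "
        f"'{new_style_name(old_name)}' nor '{old_name}'"
    )
```

Inside DIG's `GCNConv.propagate` (models.py:362 in the traceback), the failing call could then read `size = compat_getattr(self, "__check_input__")(edge_index, size)`. Any other renamed internals that method uses would need the same treatment.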
I just had the same problem. It was solved by downgrading torch_geometric to v2.2.0.
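Before choosing between patching and downgrading, it can help to confirm which torch_geometric release is installed. A minimal standard-library sketch, taking the v2.3.0 cutoff from this thread; the helper names are hypothetical:

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def parse_release(ver: str) -> Tuple[int, ...]:
    """Leading numeric part of a version string, e.g. '2.3.0rc1' -> (2, 3, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split("."))


def has_renamed_internals(ver: str) -> bool:
    """True for torch_geometric 2.3.0 and later, where this thread places the rename."""
    release = parse_release(ver)
    release += (0,) * (3 - len(release))  # pad '2.3' to (2, 3, 0)
    return release >= (2, 3, 0)


def installed_pyg_version() -> Optional[str]:
    """Installed torch_geometric version, or None if it is not installed."""
    try:
        return version("torch_geometric")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    ver = installed_pyg_version()
    if ver is None:
        print("torch_geometric is not installed")
    elif has_renamed_internals(ver):
        print(f"torch_geometric {ver}: expect the AttributeError; pin 2.2.0 or patch DIG")
    else:
        print(f"torch_geometric {ver}: pre-2.3.0 names, DIG models should run")
```

The downgrade reported here corresponds to `pip install torch_geometric==2.2.0`.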
I'm trying to run the explainable GNN (xgnn) examples, but when running the model imported from `dig.xgraph.models` I get the error above.
I'm running the experiment in a conda environment with:
- PyTorch 2.0.0
- Python 3.9
- CUDA 11.7
- torch_geometric 2.3.0
Here is the traceback:
```
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[8], line 10
7 if torch.isnan(data.y[0].squeeze()):
8 continue
---> 10 logits = model(data.x, data.edge_index)
11 prediction = logits[node_idx].argmax(-1).item()
13 _, explanation_results, related_preds = explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/dig/xgraph/models/models.py:164, in GCN_2l.forward(self, *args, **kwargs)
158 """
159 :param Required[data]: Batch - input data
160 :return:
161 """
162 x, edge_index, batch = self.arguments_read(*args, **kwargs)
--> 164 post_conv = self.relu1(self.conv1(x, edge_index))
165 for conv, relu in zip(self.convs, self.relus):
166 post_conv = relu(conv(post_conv, edge_index))
File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/dig/xgraph/models/models.py:350, in GCNConv.forward(self, x, edge_index, edge_weight)
347 x = torch.matmul(x, self.weight)
349 # propagate_type: (x: Tensor, edge_weight: OptTensor)
--> 350 out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
351 size=None)
353 if self.bias is not None:
354 out += self.bias
File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/dig/xgraph/models/models.py:362, in GCNConv.propagate(self, edge_index, size, **kwargs)
361 def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
--> 362 size = self.__check_input__(edge_index, size)
364 # Run "fused" message and aggregation (if applicable).
365 if (isinstance(edge_index, SparseTensor) and self.fuse
366 and not self._explain):
File ~/anaconda3/envs/dive_into_graphs/lib/python3.9/site-packages/torch/nn/modules/module.py:1614, in Module.__getattr__(self, name)
1612 if name in modules:
1613 return modules[name]
-> 1614 raise AttributeError("'{}' object has no attribute '{}'".format(
1615 type(self).__name__, name))
AttributeError: 'GCNConv' object has no attribute '__check_input__'
```
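The error is reported from `torch.nn.Module.__getattr__` rather than from DIG itself because Python only calls `__getattr__` after normal attribute lookup (instance, class, base classes) has failed; `Module.__getattr__` then checks the registered parameters, buffers and submodules and raises the `AttributeError` seen above. A pure-Python sketch of that fallback, using hypothetical stand-in classes (`ToyModule`, `ToyMessagePassing`, `ToyGCNConv`) and assuming the newer base class only defines the renamed `_check_input`:

```python
class ToyModule:
    """Stand-in for torch.nn.Module's attribute fallback (illustrative only)."""

    def __init__(self):
        # torch keeps parameters, buffers and submodules in separate dicts;
        # one dict of submodules is enough to show the lookup order.
        self._modules = {}

    def __getattr__(self, name):
        # Python calls this only after the normal lookup has failed.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))


class ToyMessagePassing(ToyModule):
    # ASSUMPTION: the newer base class only defines the renamed method.
    def _check_input(self, edge_index, size):
        return size


class ToyGCNConv(ToyMessagePassing):
    # Subclass written against the old dunder-style name, like DIG's GCNConv.
    def propagate(self, edge_index, size=None):
        # Names ending in "__" are not name-mangled, so this looks up
        # "__check_input__" literally and falls through to __getattr__.
        return self.__check_input__(edge_index, size)


try:
    ToyGCNConv().propagate(edge_index=[], size=None)
except AttributeError as err:
    print(err)  # 'ToyGCNConv' object has no attribute '__check_input__'
```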
I also tried GIN_2l but got the same error as with GCN_2l.
This is the part of the code that fails:
```python
import torch

# `dataset`, `device`, `model` (e.g. GCN_2l from dig.xgraph.models) and
# `explainer` (a SubgraphX instance) are defined in earlier notebook cells.

# --- Create data collector and explanation processor ---
from dig.xgraph.evaluation import XCollector
x_collector = XCollector()

index = -1
node_indices = torch.where(dataset[0].test_mask * dataset[0].y != 0)[0].tolist()
data = dataset[0]

from dig.xgraph.method.subgraphx import PlotUtils
from dig.xgraph.method.subgraphx import find_closest_node_result

# Visualization
max_nodes = 5
node_idx = node_indices[20]
print(f'explain graph node {node_idx}')
data.to(device)
logits = model(data.x, data.edge_index)  # raises the AttributeError above
prediction = logits[node_idx].argmax(-1).item()

_, explanation_results, related_preds = explainer(
    data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
explanation_results = explanation_results[prediction]
explanation_results = explainer.read_from_MCTSInfo_list(explanation_results)

plotutils = PlotUtils(dataset_name='ba_shapes', is_show=True)
explainer.visualization(explanation_results,
                        max_nodes=max_nodes,
                        plot_utils=plotutils,
                        y=data.y)
```
The code was taken from the DIG repository: https://github.com/divelab/DIG/blob/dig-stable/examples/xgraph/subgraphx.ipynb