Hi, I am trying to use a pre-trained model on the ESOL dataset.
```python
from tqdm import tqdm

import dgl
from dgllife.data import ESOL
from dgllife.model import load_pretrained
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, AttentiveFPAtomFeaturizer, \
    CanonicalBondFeaturizer, AttentiveFPBondFeaturizer

dataset_canonical = ESOL(smiles_to_bigraph, CanonicalAtomFeaturizer(), CanonicalBondFeaturizer())
model = load_pretrained('Weave_canonical_ESOL')  # Pretrained model loaded
model.eval()

for smiles, g, label in tqdm(dataset_canonical):
    nfeats = g.ndata['h']
    efeats = g.edata['e']
    label_pred = model(g, nfeats, efeats)
    print(label_pred)
    print(label)
```
This throws the following error:
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_184688/2242391364.py in <module>
      7     nfeats = g.ndata['h']
      8     efeats = g.edata['e']
----> 9     label_pred = model(g, nfeats, efeats)
     10     print(label_pred)
     11     print(label)

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/envs/dgl/lib/python3.9/site-packages/dgllife/model/model_zoo/weave_predictor.py in forward(self, g, node_feats, edge_feats)
    103             Prediction for the graphs in the batch. G for the number of graphs.
    104         """
--> 105         node_feats = self.gnn(g, node_feats, edge_feats, node_only=True)
    106         node_feats = self.node_to_graph(node_feats)
    107         g_feats = self.readout(g, node_feats)

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/envs/dgl/lib/python3.9/site-packages/dgllife/model/gnn/weave.py in forward(self, g, node_feats, edge_feats, node_only)
    208         """
    209         for i in range(len(self.gnn_layers) - 1):
--> 210             node_feats, edge_feats = self.gnn_layers[i](g, node_feats, edge_feats)
    211         return self.gnn_layers[-1](g, node_feats, edge_feats, node_only)

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/envs/dgl/lib/python3.9/site-packages/dgllife/model/gnn/weave.py in forward(self, g, node_feats, edge_feats, node_only)
    107         # Update node features
    108         node_node_feats = self.activation(self.node_to_node(node_feats))
--> 109         g.edata['e2n'] = self.activation(self.edge_to_node(edge_feats))
    110         g.update_all(fn.copy_edge('e2n', 'm'), fn.sum('m', 'e2n'))
    111         edge_node_feats = g.ndata.pop('e2n')

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/modules/linear.py in forward(self, input)
    101
    102     def forward(self, input: Tensor) -> Tensor:
--> 103         return F.linear(input, self.weight, self.bias)
    104
    105     def extra_repr(self) -> str:

~/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1846     if has_torch_function_variadic(input, weight, bias):
   1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
   1849
   1850

RuntimeError: mat1 and mat2 shapes cannot be multiplied (68x12 and 13x256)
```
I checked the shape of the graph's edge features and the construction of the WeavePredictor, and found that they do not match:
```python
>>> smiles, g, label = dataset_canonical[0]
>>> print(g.edata['e'].shape)
torch.Size([68, 12])
>>> print(model)
WeavePredictor(
  (gnn): WeaveGNN(
    (gnn_layers): ModuleList(
      (0): WeaveLayer(
        (node_to_node): Linear(in_features=74, out_features=256, bias=True)
        (edge_to_node): Linear(in_features=13, out_features=256, bias=True)
        (update_node): Linear(in_features=512, out_features=256, bias=True)
        (left_node_to_edge): Linear(in_features=74, out_features=256, bias=True)
        (right_node_to_edge): Linear(in_features=74, out_features=256, bias=True)
        (edge_to_edge): Linear(in_features=13, out_features=256, bias=True)
        (update_edge): Linear(in_features=768, out_features=256, bias=True)
      )
      ...
```
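A quick way to confirm the mismatch programmatically (a sketch; the `gnn.gnn_layers[0].edge_to_node` path just follows the submodule names visible in the printout above, and `in_features` is the standard `nn.Linear` attribute):

```python
# Sketch: compare the edge-feature width produced by the featurizer with the
# width the first WeaveLayer's edge_to_node Linear expects.
smiles, g, label = dataset_canonical[0]
print(g.edata['e'].shape[1])                             # 12, from CanonicalBondFeaturizer()
print(model.gnn.gnn_layers[0].edge_to_node.in_features)  # 13, what the pretrained weights require
```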
How can I solve this error? Thanks a lot for your help!
Can you take a look at #162? I believe it's the same issue.
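A minimal sketch of what that fix likely looks like (an assumption here, since #162 is not quoted in this thread: the pretrained Weave models expect molecular graphs with self-loops, and enabling the self-loop options adds one extra edge-feature dimension, which explains the 12 vs. 13 mismatch):

```python
# Sketch under the assumption that 'Weave_canonical_ESOL' was trained on graphs
# with self-loops: enable self-loops in both the graph constructor and the bond
# featurizer so the edge features come out 13-dimensional.
from functools import partial

from dgllife.data import ESOL
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer

dataset_canonical = ESOL(
    partial(smiles_to_bigraph, add_self_loop=True),  # add a self-loop edge per atom
    CanonicalAtomFeaturizer(),                       # 74-dim node features in g.ndata['h']
    CanonicalBondFeaturizer(self_loop=True),         # 12 + 1 = 13-dim edge features in g.edata['e']
)
```

With self-loops enabled, `g.edata['e']` should match the 13 `in_features` of the pretrained `edge_to_node` layer.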
Ok, that works! Thanks a lot!