Skip to content

Commit

Permalink
Update equibind.py - copy_e
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasTR authored Apr 26, 2023
1 parent 91c6ba2 commit 56849b8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions models/equibind.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,8 @@ def forward(self, lig_graph, rec_graph, coords_lig, h_feats_lig, original_ligand
else:
x_evolved_rec = coords_rec

lig_graph.update_all(fn.copy_edge('msg', 'm'), fn.mean('m', 'aggr_msg'))
rec_graph.update_all(fn.copy_edge('msg', 'm'), fn.mean('m', 'aggr_msg'))
lig_graph.update_all(fn.copy_e('msg', 'm'), fn.mean('m', 'aggr_msg'))
rec_graph.update_all(fn.copy_e('msg', 'm'), fn.mean('m', 'aggr_msg'))

if self.fine_tune:
x_evolved_lig = x_evolved_lig + self.att_mlp_cross_coors_V_lig(h_feats_lig) * (
Expand Down Expand Up @@ -515,7 +515,7 @@ def forward(self, lig_graph, rec_graph, coords_lig, h_feats_lig, original_ligand
Loss = torch.sum((d_squared - geometry_graph.edata['feat'] ** 2)**2) # this is the loss whose gradient we are calculating here
grad_d_squared = 2 * (x_evolved_lig[src] - x_evolved_lig[dst])
geometry_graph.edata['partial_grads'] = 2 * (d_squared - geometry_graph.edata['feat'] ** 2)[:,None] * grad_d_squared
geometry_graph.update_all(fn.copy_edge('partial_grads', 'partial_grads_msg'),
geometry_graph.update_all(fn.copy_e('partial_grads', 'partial_grads_msg'),
fn.sum('partial_grads_msg', 'grad_x_evolved'))
grad_x_evolved = geometry_graph.ndata['grad_x_evolved']
x_evolved_lig = x_evolved_lig + self.geometry_reg_step_size * grad_x_evolved
Expand Down

0 comments on commit 56849b8

Please sign in to comment.