Skip to content

Commit

Permalink
Optimize the speed of _compute_3body implementation (#283)
Browse files Browse the repository at this point in the history
* improve TensorNet model coverage

* Update pyproject.toml

Signed-off-by: Tsz Wai Ko <[email protected]>

* Improve the unit test for SO(3) equivarance in TensorNet class

* improve SO3Net model class coverage and simplify TensorNet implementations

* improve the coverage in MLP_norm class

* Improve the implementation of three-body interactions

* fixed black

* Optimize the speed of _compute_3body class

---------

Signed-off-by: Tsz Wai Ko <[email protected]>
  • Loading branch information
kenko911 authored Jul 5, 2024
1 parent da8da3c commit d59abe2
Showing 1 changed file with 20 additions and 25 deletions.
45 changes: 20 additions & 25 deletions src/matgl/graph/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float, directed: bool =
if directed:
lg = _create_directed_line_graph(graph_with_three_body, threebody_cutoff)
else:
lg, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = _compute_3body(graph_with_three_body)
lg = _compute_3body(graph_with_three_body)

return lg

Expand Down Expand Up @@ -174,49 +174,44 @@ def _compute_3body(g: dgl.DGLGraph):
n_triple_s (np.ndarray): number of three-body angles for each structure
"""
n_atoms = g.num_nodes()
first_col = g.edges()[0].cpu().numpy().reshape(-1, 1)
all_indices = np.arange(n_atoms).reshape(1, -1)
n_bond_per_atom = np.count_nonzero(first_col == all_indices, axis=0)
first_col = g.edges()[0].cpu().numpy()

# Count bonds per atom efficiently
n_bond_per_atom = np.bincount(first_col, minlength=n_atoms)

n_triple_i = n_bond_per_atom * (n_bond_per_atom - 1)
n_triple = np.sum(n_triple_i)
n_triple = n_triple_i.sum()
n_triple_ij = np.repeat(n_bond_per_atom - 1, n_bond_per_atom)
triple_bond_indices = np.empty((n_triple, 2), dtype=matgl.int_np) # type: ignore

triple_bond_indices = np.empty((n_triple, 2), dtype=matgl.int_np)

start = 0
cs = 0
for n in n_bond_per_atom:
if n > 0:
"""
triple_bond_indices is generated from all pair permutations of atom indices. The
numpy version below does this with much greater efficiency. The equivalent slow
code is:
```
for j, k in itertools.permutations(range(n), 2):
triple_bond_indices[index] = [start + j, start + k]
```
"""
r = np.arange(n)
x, y = np.meshgrid(r, r, indexing="xy")
c = np.stack([y.ravel(), x.ravel()], axis=1)
final = c[c[:, 0] != c[:, 1]]
triple_bond_indices[start : start + (n * (n - 1)), :] = final + cs
final = np.stack([y.ravel(), x.ravel()], axis=1)
mask = final[:, 0] != final[:, 1]
final = final[mask]
triple_bond_indices[start : start + n * (n - 1)] = final + cs
start += n * (n - 1)
cs += n

n_triple_s = [np.sum(n_triple_i[0:n_atoms])]
src_id = torch.tensor(triple_bond_indices[:, 0], dtype=matgl.int_th)
dst_id = torch.tensor(triple_bond_indices[:, 1], dtype=matgl.int_th)
l_g = dgl.graph((src_id, dst_id)).to(g.device)
three_body_id = torch.concatenate(l_g.edges())
n_triple_ij = torch.tensor(n_triple_ij, dtype=matgl.int_th).to(g.device)
max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0
three_body_id = torch.cat(l_g.edges())
n_triple_ij = torch.tensor(n_triple_ij, dtype=matgl.int_th, device=g.device)

max_three_body_id = three_body_id.max().item() + 1 if three_body_id.numel() > 0 else 0

l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id]
l_g.ndata["bond_vec"] = g.edata["bond_vec"][:max_three_body_id]
l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][:max_three_body_id]
l_g.ndata["n_triple_ij"] = n_triple_ij[:max_three_body_id]
n_triple_s = torch.tensor(n_triple_s, dtype=matgl.int_th) # type: ignore
return l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s

return l_g


def _create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) -> dgl.DGLGraph:
Expand Down

0 comments on commit d59abe2

Please sign in to comment.