-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathmace_irreps_tools.py
86 lines (69 loc) · 2.89 KB
/
mace_irreps_tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
###########################################################################################
# Elementary tools for handling irreducible representations
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from typing import List, Tuple
import torch
from e3nn import o3
from e3nn.util.jit import compile_mode
# Based on mir-group/nequip
def tp_out_irreps_with_instructions(
    irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps
) -> Tuple[o3.Irreps, List]:
    """Build the output irreps and "uvu" instructions for a tensor product.

    Each entry ``(mul, ir_in)`` of ``irreps1`` is combined with each entry
    ``ir_edge`` of ``irreps2``. Every irrep produced by ``ir_in * ir_edge``
    that also appears in ``target_irreps`` becomes one output path. That path
    carries multiplicity ``mul``, taken from the ``irreps1`` entry; the
    multiplicity of the ``irreps2`` entry is ignored.

    The collected output irreps are sorted (not simplified here), and the
    output index of every instruction is remapped to its position in the
    sorted irreps.

    Args:
        irreps1: First input of the tensor product.
        irreps2: Second input of the tensor product.
        target_irreps: Irreps allowed in the output; other paths are dropped.

    Returns:
        ``(irreps_out, instructions)``. ``irreps_out`` is the sorted
        ``o3.Irreps`` of all kept paths. ``instructions`` is a list of
        ``(i_in1, i_in2, i_out, "uvu", True)`` tuples ordered by ``i_out``.
    """
    trainable = True

    # One entry per kept path, in discovery order: (i, j, mul, ir_out).
    paths = [
        (i, j, mul, ir_out)
        for i, (mul, ir_in) in enumerate(irreps1)
        for j, (_, ir_edge) in enumerate(irreps2)
        for ir_out in ir_in * ir_edge  # | l1 - l2 | <= l <= l1 + l2
        if ir_out in target_irreps
    ]

    # Sorting groups equal irreps together so a following o3.Linear can
    # simplify them. old_to_new[k] is the sorted position of the k-th path.
    irreps_out, old_to_new, _ = o3.Irreps(
        [(mul, ir_out) for _, _, mul, ir_out in paths]
    ).sort()

    # The k-th path originally wrote to output slot k; point it at its
    # sorted slot and order the instructions by that slot.
    instructions = sorted(
        (
            (i, j, old_to_new[k], "uvu", trainable)
            for k, (i, j, _, _) in enumerate(paths)
        ),
        key=lambda instruction: instruction[2],
    )
    return irreps_out, instructions
def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps:
    """Give each irrep of ``irreps`` the multiplicity it has in ``target_irreps``.

    For every ``(mul, ir)`` entry of ``irreps``, the input multiplicity is
    discarded. It is replaced by the ``(mul, ir)`` pair of the first entry of
    ``target_irreps`` whose irrep equals ``ir``. The result keeps the entry
    order of ``irreps``.

    Args:
        irreps: Irreps whose irrep types are kept (assumed simplified).
        target_irreps: Irreps supplying the multiplicity for each irrep type.

    Returns:
        ``o3.Irreps`` with the irreps of ``irreps`` and multiplicities from
        ``target_irreps``.

    Raises:
        RuntimeError: If an irrep of ``irreps`` does not occur in
            ``target_irreps``.
    """
    # Assuming simplified irreps: only the first matching target entry is used.
    resolved = []
    for _, wanted in irreps:
        match = next(
            ((mul, ir_out) for mul, ir_out in target_irreps if wanted == ir_out),
            None,
        )
        if match is None:
            raise RuntimeError(f"{wanted} not in {target_irreps}")
        resolved.append(match)
    return o3.Irreps(resolved)
@compile_mode("script")
class reshape_irreps(torch.nn.Module):
    """Reshape flat irreps features into ``(batch, mul, sum of ir.dim)``.

    The input is a 2-D tensor whose columns are the concatenation of one
    ``mul * ir.dim`` wide segment per ``(mul, ir)`` entry of ``irreps``.
    Each segment is reshaped to ``(batch, mul, ir.dim)``, and the segments
    are concatenated along the last axis.

    Note: because of that final concatenation, every entry must share the
    same ``mul`` (trivially true for a single entry). Columns beyond the
    total irreps width are not read.
    """

    def __init__(self, irreps: o3.Irreps) -> None:
        super().__init__()
        self.irreps = o3.Irreps(irreps)
        # Per-entry irrep dimension and multiplicity, in irreps order.
        self.dims = [ir.dim for _, ir in self.irreps]
        self.muls = [mul for mul, _ in self.irreps]

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Split ``tensor`` (2-D) per irreps entry, reshape, and concatenate."""
        batch, _ = tensor.shape
        offset = 0
        blocks = []
        for mul, d in zip(self.muls, self.dims):
            width = mul * d
            # [batch, mul * d] column slice -> [batch, mul, d]
            segment = tensor[:, offset : offset + width]
            blocks.append(segment.reshape(batch, mul, d))
            offset += width
        return torch.cat(blocks, dim=-1)