Commit c4c4030: add special type of tucker format

CheukHinHoJerry committed Nov 5, 2024
1 parent ac44b5e
Showing 7 changed files with 148 additions and 34 deletions.
28 changes: 18 additions & 10 deletions mace/modules/blocks.py
@@ -8,6 +8,7 @@
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import math
import torch.nn.functional
from e3nn import nn, o3
from e3nn.util.jit import compile_mode
@@ -21,7 +22,8 @@
reshape_irreps,
tp_out_irreps_with_instructions,
make_tp_irreps,
make_tucker_irreps
make_tucker_irreps,
make_tucker_irreps_flexible
)
from .radial import (
AgnesiTransform,
@@ -242,7 +244,7 @@ def __init__(self, tensor_format, correlation):

def forward(self, message) -> torch.Tensor:
batch_size, dim, features = message.shape
if self.tensor_format in ["symmetric_cp", "symmetric_tucker"]:
if self.tensor_format in ["symmetric_cp", "symmetric_tucker", "flexible_symmetric_tucker"]:
return message
elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
return message
@@ -277,8 +279,13 @@ def __init__(
internal_weights=True,
shared_weights=True,
)
elif tensor_format in ["symmetric_tucker", "non_symmetric_tucker"]:
tucker_irreps = make_tucker_irreps(target_irreps, correlation)
elif tensor_format in ["flexible_symmetric_tucker", "symmetric_tucker", "non_symmetric_tucker"]:
if tensor_format == "flexible_symmetric_tucker":
tucker_irreps = make_tucker_irreps_flexible(target_irreps, correlation)
else:
tucker_irreps = make_tucker_irreps(target_irreps, correlation)
print("tucker irreps:", tucker_irreps)
print("target irreps:", target_irreps)
self.linear = o3.Linear(
tucker_irreps,
target_irreps,
@@ -293,7 +300,6 @@ def forward(
node_attrs: torch.Tensor,
) -> torch.Tensor:
node_feats = self.symmetric_contractions(node_feats, node_attrs)
print("shape after symmstric contractions: ", node_feats.shape)
if self.use_sc and sc is not None:
return self.linear(node_feats) + sc
return self.linear(node_feats)
@@ -683,7 +689,7 @@ def _setup(self) -> None:
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps

if self.tensor_format in ["symmetric_cp", "symmetric_tucker"]:
if self.tensor_format in ["symmetric_cp", "symmetric_tucker", "flexible_symmetric_tucker"]:
self.linear = o3.Linear(
irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
)
@@ -730,19 +736,21 @@ def forward(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]

if self.tensor_format in ["symmetric_cp", "symmetric_tucker"]:
if self.tensor_format in ["symmetric_cp", "symmetric_tucker", "flexible_symmetric_tucker"]:
message = self.linear(original_message) / self.avg_num_neighbors
return (
if self.tensor_format in ["flexible_symmetric_tucker", ]:
return (message, sc)
else:
return (
self.tensor_format_layer(self.reshape(message)),
sc,
) # symmetric_cp: [n_nodes, channels, (lmax + 1)**2]
) # symmetric_cp: [n_nodes, channels, (lmax + 1)**2]
elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
message = self.reshape[0](self.linear[0](original_message))
message = message.unsqueeze(-1)
for idx in range(1, self.correlation):
_message = self.reshape[idx](self.linear[idx](original_message)).unsqueeze(-1)
message = torch.cat((message, _message), dim = -1)
print("shape of message: ", message.shape)
return (
message / self.avg_num_neighbors,
sc
32 changes: 32 additions & 0 deletions mace/modules/irreps_tools.py
@@ -7,6 +7,7 @@
from typing import List, Tuple

import torch
import math
from e3nn import o3
from e3nn.util.jit import compile_mode

@@ -30,6 +31,16 @@ def make_tucker_irreps(target_irreps, correlation):
tp_irreps += o3.Irreps(f"{num_feats}x{tmp_irreps[0].ir}")
return tp_irreps

def make_tucker_irreps_flexible(target_irreps, correlation):
tp_irreps = o3.Irreps()
for ir in target_irreps:
tmp_irreps = o3.Irreps(str(ir))
num_feats = 0
for nu in range(1, correlation + 1):
num_feats += (math.ceil(ir.mul ** (1 / nu))) ** nu
tp_irreps += o3.Irreps(f"{num_feats}x{tmp_irreps[0].ir}")
return tp_irreps
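
# A minimal sizing sketch, with illustrative numbers: for
# target_irreps = o3.Irreps("128x0e + 128x1o") and correlation = 3,
# make_tucker_irreps_flexible(target_irreps, 3) returns 488x0e+488x1o,
# since each irrep receives
# 128 + ceil(128 ** (1 / 2)) ** 2 + ceil(128 ** (1 / 3)) ** 3
# = 128 + 144 + 216 = 488 features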

# Based on mir-group/nequip
def tp_out_irreps_with_instructions(
irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps
@@ -104,6 +115,27 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
out.append(field)
return torch.cat(out, dim=-1)

@compile_mode("script")
class inverse_reshape_irreps(torch.nn.Module):
def __init__(self, irreps: o3.Irreps) -> None:
super().__init__()
self.irreps = o3.Irreps(irreps)
self.dims = []
self.muls = []
for mul, ir in self.irreps:
d = ir.dim
self.dims.append(d)
self.muls.append(mul)

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
out = []
ix = 0
for mul, d in zip(self.muls, self.dims):
field = tensor[:, ix : ix + mul, :d] # [batch, mul, repr]
field = field.reshape(tensor.shape[0], -1) # Flatten [batch, mul * repr]
out.append(field)
ix += mul
return torch.cat(out, dim=-1)
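
# A minimal shape sketch, with illustrative sizes: the input is expected to
# stack the irrep blocks along the channel axis, each padded to the widest
# ir.dim, e.g.
# irreps = o3.Irreps("4x0e + 4x1o")
# packed = torch.randn(2, 8, 3) # [batch, 4 + 4 channels, max ir.dim = 3]
# inverse_reshape_irreps(irreps)(packed).shape # torch.Size([2, 16]) = 4*1 + 4*3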

def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor:
mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device)
3 changes: 1 addition & 2 deletions mace/modules/models.py
@@ -132,7 +132,6 @@ def __init__(
use_sc_first = True

node_feats_irreps_out = inter.target_irreps
print(node_feats_irreps_out)
prod = EquivariantProductBasisBlock(
node_feats_irreps=node_feats_irreps_out,
target_irreps=hidden_irreps,
@@ -149,7 +148,7 @@ def __init__(
self.readouts.append(
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)
elif tensor_format in ["symmetric_tucker", "non_symmetric_tucker"]:
elif tensor_format in ["symmetric_tucker", "non_symmetric_tucker", "flexible_symmetric_tucker"]:
self.readouts.append(
#LinearReadoutBlock(make_tp_irreps(hidden_irreps, correlation[0]), o3.Irreps(f"{len(heads)}x0e"))
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
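# note: hidden_irreps suffices here because the product block's o3.Linear
# already maps the (flexible) tucker irreps back to target_irreps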
95 changes: 83 additions & 12 deletions mace/modules/symmetric_contraction.py
@@ -9,12 +9,14 @@

import opt_einsum_fx
import torch
import math
import torch.fx
from e3nn import o3
from e3nn.util.codegen import CodeGenMixin
from e3nn.util.jit import compile_mode

from mace.tools.cg import U_matrix_real
from .irreps_tools import reshape_irreps, inverse_reshape_irreps

BATCH_EXAMPLE = 10
ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"]
@@ -74,6 +76,7 @@ def __init__(
Contraction(
irreps_in=self.irreps_in,
irrep_out=o3.Irreps(str(irrep_out.ir)),
irrep_out_withmul = irrep_out,
correlation=correlation[irrep_out],
internal_weights=self.internal_weights,
num_elements=num_elements,
@@ -94,6 +97,7 @@ def __init__(
self,
irreps_in: o3.Irreps,
irrep_out: o3.Irreps,
irrep_out_withmul: o3.Irreps,
correlation: int,
internal_weights: bool = True,
num_elements: Optional[int] = None,
@@ -104,8 +108,10 @@ def __init__(
self.num_features = irreps_in.count((0, 1))
self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in])
self.correlation = correlation
self.irreps_in = irreps_in
self.irrep_out = irrep_out

self.irrep_out_withmul = irrep_out_withmul
dtype = torch.get_default_dtype()
for nu in range(1, correlation + 1):
U_matrix = U_matrix_real(
@@ -118,6 +124,42 @@ def __init__(
self.register_buffer(f"U_matrix_{nu}", U_matrix)

self.tensor_format = tensor_format

# if this is a tucker format, we need this further contraction
# to allow more A_klm basis functions to be formed while
# preventing the k^nu scaling
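# e.g., with illustrative numbers: k = 128 channels at nu = 3 would give a
# dense symmetric weight tensor of 128**3 ~ 2.1e6 couplings per path, while
# ceil(128 ** (1 / 3)) = 6 channels per mode gives only 6**3 = 216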
self.irreps_mid = o3.Irreps()

# control dimension flexibly for each nu
# TODO: generalize to allow more flexible dimension
self.irreps_nu = []
self.linear_nu = torch.nn.ModuleList([])
self.linear_nu_reshape = torch.nn.ModuleList([])

if self.tensor_format in ["flexible_symmetric_tucker", ]:
for irrep_in in self.irreps_in:
self.irreps_mid += o3.Irreps(f"{irrep_out_withmul.mul}x{irrep_in.ir}")
self.linear = o3.Linear(self.irreps_in,
self.irreps_mid,
internal_weights=True,
shared_weights=True,)

#self.reshape = reshape_irreps(self.irreps_mid)
# update num_features too
self.num_features = self.irreps_mid.count((0, 1))
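# note: irreps_mid carries irrep_out_withmul.mul copies of each input irrep,
# so (assuming a single 0e block in irreps_in) num_features now equals the
# output multiplicity rather than the input one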

for nu in range(correlation, 0, -1):
tmp_irreps = o3.Irreps()
for irrep_mid in self.irreps_mid:
tmp_irreps += o3.Irreps(f"{math.ceil((irrep_mid.mul) ** (1 / nu))}x{irrep_mid.ir}")
self.irreps_nu.append(tmp_irreps)
self.linear_nu.append(o3.Linear(self.irreps_mid,
tmp_irreps,
internal_weights=True,
shared_weights=True
))
self.linear_nu_reshape.append(reshape_irreps(tmp_irreps))

# Tensor contraction equations
self.contractions_weighting = torch.nn.ModuleList()
self.contractions_features = torch.nn.ModuleList()
@@ -138,6 +180,11 @@ def __init__(
torch.randn((num_elements, num_params, self.num_features))
/ num_params
)
elif tensor_format == "flexible_symmetric_tucker":
w = torch.nn.Parameter(
torch.randn([num_elements, num_params,] + [math.ceil(self.num_features ** (1 / correlation)),])
/ num_params
)
elif tensor_format == "symmetric_tucker":
w = torch.nn.Parameter(
torch.randn([num_elements, num_params,] + [self.num_features,])
@@ -189,6 +236,13 @@ def __init__(
torch.randn((num_elements, num_params, self.num_features))
/ num_params
)
elif tensor_format == "flexible_symmetric_tucker":
# outer product taken in model.forward to form the symmetrized parameter tensor
# this can be improved
w = torch.nn.Parameter(
torch.randn((num_elements, num_params, math.ceil(self.num_features ** (1 / i))))
/ num_params
)
elif tensor_format == "symmetric_tucker":
# outer product taken in model.forward to form the symmetrized parameter tensor
# this can be improved
@@ -259,9 +313,14 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
irrep_out = self.irrep_out
num_equivariance = 2 * irrep_out.lmax + 1
if "tucker" in self.tensor_format:
# this allows generalization to a different num_feats for
# each level of L
if self.tensor_format == "flexible_symmetric_tucker":
x = self.linear(x) #self.reshape(self.linear(x))

outs = dict()
out_channel_idx = "".join([CHANNEL_ALPHANET[j] for j in range(self.correlation)])
idx = 0
for nu in range(self.correlation, 0, -1):
num_params = self.U_tensors(nu).size()[-1]
num_ell = self.U_tensors(nu).size()[-2]
@@ -278,9 +337,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ ["k"]
)),
self.U_tensors(nu),
x_nu)
self.linear_nu_reshape[idx](self.linear_nu[idx](x_nu)) \
if self.tensor_format=="flexible_symmetric_tucker" else x_nu)
else:
# contractions to be done for U_tensors(nu)
idx2 = 0
for nu2 in range(self.correlation, nu - 1, -1):
# contraction for current nu
# [ALPHABET[j] for j in range(nu + min(irrep_out.lmax, 1) - 1)]
Expand All @@ -294,7 +355,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
))
,
self.U_tensors(nu),
x_nu
self.linear_nu_reshape[idx](self.linear_nu[idx](x_nu)) \
if self.tensor_format=="flexible_symmetric_tucker" else x_nu
)
# also contract previous nu and expand the tensor product basis
else:
@@ -307,15 +369,22 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ ["k"]
),
outs[nu2],
x_nu
)

self.linear_nu_reshape[idx2](self.linear_nu[idx2](x_nu)) \
if self.tensor_format=="flexible_symmetric_tucker" else x_nu)
idx2 += 1 # for each nu2

idx += 1 # for each nu


# for nu in range(self.correlation, 0, -1):
# print(f"outs[{nu}] shape: ", outs[nu].shape)
# print("before product basis")
# product basis coefficients layer
for nu in range(self.correlation, 0, -1):
if nu == self.correlation:
if self.tensor_format == "non_symmetric_tucker":
c_tensor = torch.einsum(f"ek{out_channel_idx[:nu]},be->bk{out_channel_idx[:nu]}", self.weights_max, y)
elif self.tensor_format == "symmetric_tucker":
elif self.tensor_format in ["symmetric_tucker", "flexible_symmetric_tucker"]:
c_tensor = torch.einsum("ekc,be->bkc", self.weights_max, y)
# outer product to symmetrize tensor
c_tensor = torch.einsum("".join([f"bk{out_channel_idx[i]}," for i in range(nu-1)]
Expand All @@ -327,7 +396,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
else:
if self.tensor_format == "non_symmetric_tucker":
c_tensor = torch.einsum(f"ek{out_channel_idx[:nu]},be->bk{out_channel_idx[:nu]}", self.weights[self.correlation - nu - 1], y)
elif self.tensor_format == "symmetric_tucker":
elif self.tensor_format in ["symmetric_tucker", "flexible_symmetric_tucker"]:
c_tensor = torch.einsum("ekc,be->bkc", self.weights[self.correlation - nu - 1], y)
# outer product to symmetrize tensor
c_tensor = torch.einsum("".join([f"bk{out_channel_idx[i]}," for i in range(nu-1)]
Expand All @@ -349,17 +418,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
c_tensor,
)
for nu in range(self.correlation, 0, -1):
shape_outnu = [outs[nu].shape[0]] + [self.num_features] * nu
if self.tensor_format == "flexible_symmetric_tucker":
shape_outnu = [outs[nu].shape[0]] + [math.ceil(self.num_features ** (1 / nu))] * nu
else:
shape_outnu = [outs[nu].shape[0]] + [self.num_features] * nu
if irrep_out.lmax > 0:
shape_outnu += [num_equivariance]
# combine all the features channels
outs[nu] = outs[nu].reshape(*shape_outnu)
# reshape kLM
outs[nu] = outs[nu].reshape(outs[nu].shape[0], -1)

# divide by factorial(nu) because of the extra work done for convenience
return torch.cat([outs[nu] for nu in range(self.correlation, 0, -1)], dim = 1)
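
# sanity note for the flexible format: per irrep the concatenated width is
# the sum over nu of ceil(k ** (1 / nu)) ** nu features (times the 2l+1
# m-components), which matches the sizing produced by
# make_tucker_irreps_flexible, so the o3.Linear in
# EquivariantProductBasisBlock can consume it directly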


## previous CP implementation
elif "cp" in self.tensor_format:
if self.tensor_format == "symmetric_cp":
out = self.graph_opt_main(
1 change: 1 addition & 0 deletions mace/tools/arg_parser.py
@@ -252,6 +252,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"non_symmetric_cp",
"symmetric_tucker",
"non_symmetric_tucker",
"flexible_symmetric_tucker"
]
)
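
# the new format can then presumably be selected at training time via
# --tensor_format flexible_symmetric_tucker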

14 changes: 8 additions & 6 deletions mace/tools/arg_parser_tools.py
@@ -61,13 +61,15 @@ def check_args(args):
.sort()
.irreps.simplify()
)
assert (
len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"
if args.tensor_format in ["symmetric_cp", "symmetric_tucker", "non_symmetric_cp", "non_symmetric_tucker"]:
assert (
len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"
elif args.hidden_irreps is not None:
assert (
len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"
if args.tensor_format in ["symmetric_cp", "symmetric_tucker", "non_symmetric_cp", "non_symmetric_tucker"]:
assert (
len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"

args.num_channels = list(
{irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}