Skip to content

Commit

Permalink
draft symmetric tucker
Browse files Browse the repository at this point in the history
  • Loading branch information
CheukHinHoJerry committed Oct 6, 2024
1 parent 3382ad7 commit 731b3e8
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 53 deletions.
95 changes: 82 additions & 13 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
mask_head,
reshape_irreps,
tp_out_irreps_with_instructions,
make_tp_irreps
)
from .radial import (
AgnesiTransform,
Expand All @@ -31,7 +32,6 @@
)
from .symmetric_contraction import SymmetricContraction


@compile_mode("script")
class LinearNodeEmbeddingBlock(torch.nn.Module):
def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps):
Expand Down Expand Up @@ -208,6 +208,48 @@ def forward(
radial = self.bessel_fn(edge_lengths) # [n_edges, n_basis]
return radial * cutoff # [n_edges, n_basis]

def tensor_power_einsum(tensor, N):
batch_size, dim, features = tensor.shape

# Create the equation string
indices = [chr(ord('a') + i) for i in range(N)] # Generate indices like 'a', 'b', 'c', ...
eq = ','.join(['bi' + 'f' for _ in range(N)]) + '->b' + ''.join(indices) + 'f'

# Prepare the list of tensors
tensors = [tensor] * N

# Perform einsum
result = torch.einsum(eq, *tensors)

# Reshape to [batch_size, dim ** N, features]
result = result.reshape(batch_size, dim ** N, features)
return result

@compile_mode("script")
class TensorFormatBlock(torch.nn.Module):
def __init__(self, tensor_format, correlation):
super().__init__()

self.tensor_format = tensor_format
self.correlation = correlation
#self.irreps_in = irreps_in
#self.indices = [chr(ord('a') + i) for i in range(N)]
#self.eq = ','.join(['bi' + 'f' for _ in range(N)]) + '->b' + ''.join(indices) + 'f'

def forward(self, message) -> torch.Tensor:
batch_size, dim, features = message.shape
if self.tensor_format == "symmetric_cp":
return message
elif self.tensor_format == "symmetric_tucker":
return message
# message = [message] * correlation
# message = torch.einsum(eq, *tensors)
# # K = message.shape[-2]
# # for i in range(self.correlation - 1):
# # message = message.unsqueeze(-2)
# # message = message.repeat([1, 1, ] + [K, ] * (self.correlation - 1) + [1])
# return message.reshape(batch_size, dim ** correlation, features)


@compile_mode("script")
class EquivariantProductBasisBlock(torch.nn.Module):
Expand All @@ -218,6 +260,7 @@ def __init__(
correlation: int,
use_sc: bool = True,
num_elements: Optional[int] = None,
tensor_format = "symmetric_cp",
) -> None:
super().__init__()

Expand All @@ -227,14 +270,24 @@ def __init__(
irreps_out=target_irreps,
correlation=correlation,
num_elements=num_elements,
tensor_format=tensor_format
)
# Update linear
self.linear = o3.Linear(
target_irreps,
target_irreps,
internal_weights=True,
shared_weights=True,
)
if tensor_format == "symmetric_cp":
self.linear = o3.Linear(
target_irreps,
target_irreps,
internal_weights=True,
shared_weights=True,
)
elif tensor_format == "symmetric_tucker":
tucker_irreps = make_tp_irreps(target_irreps, correlation)
self.linear = o3.Linear(
tucker_irreps,
target_irreps,
internal_weights=True,
shared_weights=True,
)

def forward(
self,
Expand All @@ -243,6 +296,7 @@ def forward(
node_attrs: torch.Tensor,
) -> torch.Tensor:
node_feats = self.symmetric_contractions(node_feats, node_attrs)
print("shape after symmstric contractions: ", node_feats.shape)
if self.use_sc and sc is not None:
return self.linear(node_feats) + sc
return self.linear(node_feats)
Expand All @@ -259,7 +313,9 @@ def __init__(
target_irreps: o3.Irreps,
hidden_irreps: o3.Irreps,
avg_num_neighbors: float,
correlation: int,
radial_MLP: Optional[List[int]] = None,
tensor_format: str = "symmetric_cp",
) -> None:
super().__init__()
self.node_attrs_irreps = node_attrs_irreps
Expand All @@ -272,7 +328,9 @@ def __init__(
if radial_MLP is None:
radial_MLP = [64, 64, 64]
self.radial_MLP = radial_MLP

self.tensor_format = tensor_format
self.correlation = correlation

self._setup()

@abstractmethod
Expand Down Expand Up @@ -630,12 +688,21 @@ def _setup(self) -> None:
self.linear = o3.Linear(
irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
)
# "4x0e + 4x1o"
# "4**corrlatiox0e + 4**correlationx1o"

# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
)
if self.tensor_format == "symmetric_cp":
self.skip_tp = o3.FullyConnectedTensorProduct(
self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
)
elif self.tensor_format == "symmetric_tucker":
self.skip_tp = o3.FullyConnectedTensorProduct(
self.node_feats_irreps, self.node_attrs_irreps, \
self.hidden_irreps
)
self.reshape = reshape_irreps(self.irreps_out)
self.tensor_format_layer = TensorFormatBlock(self.tensor_format, self.correlation)

def forward(
self,
Expand All @@ -657,11 +724,13 @@ def forward(
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]

message = self.linear(message) / self.avg_num_neighbors
return (
self.reshape(message),
self.tensor_format_layer(self.reshape(message)),
sc,
) # [n_nodes, channels, (lmax + 1)**2]
) # symmetric_cp: [n_nodes, channels, (lmax + 1)**2]
# symmetric_tucker: [n_nodes,] + [channels] * correlation + [(lmax+1)**2 ,]


@compile_mode("script")
Expand Down
9 changes: 9 additions & 0 deletions mace/modules/irreps_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@
from e3nn import o3
from e3nn.util.jit import compile_mode

def make_tp_irreps(target_irreps, correlation):
"""
multiply irreps from eg. 4x0e + 4x1o -> 4**correlation + 4 ** correlation
"""
tp_irreps = o3.Irreps()
for ir in target_irreps:
tmp_irreps = o3.Irreps(str(ir))
tp_irreps += (tmp_irreps * ((tmp_irreps.num_irreps) ** (correlation - 1))).simplify()
return tp_irreps

# Based on mir-group/nequip
def tp_out_irreps_with_instructions(
Expand Down
26 changes: 23 additions & 3 deletions mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
get_symmetric_displacement,
)

from .irreps_tools import (
make_tp_irreps
)
# pylint: disable=C0302


Expand Down Expand Up @@ -62,6 +65,7 @@ def __init__(
radial_MLP: Optional[List[int]] = None,
radial_type: Optional[str] = "bessel",
heads: Optional[List[str]] = None,
tensor_format = "symmetric_cp",
):
super().__init__()
self.register_buffer(
Expand Down Expand Up @@ -115,7 +119,10 @@ def __init__(
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps,
avg_num_neighbors=avg_num_neighbors,
correlation=correlation[0],
radial_MLP=radial_MLP,
#
tensor_format=tensor_format,
)
self.interactions = torch.nn.ModuleList([inter])

Expand All @@ -131,21 +138,33 @@ def __init__(
correlation=correlation[0],
num_elements=num_elements,
use_sc=use_sc_first,
#
tensor_format=tensor_format,
)
self.products = torch.nn.ModuleList([prod])

self.readouts = torch.nn.ModuleList()
self.readouts.append(
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)
if tensor_format == "symmetric_cp":
self.readouts.append(
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)
elif tensor_format == "symmetric_tucker":
self.readouts.append(
#LinearReadoutBlock(make_tp_irreps(hidden_irreps, correlation[0]), o3.Irreps(f"{len(heads)}x0e"))
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)

for i in range(num_interactions - 1):
if i == num_interactions - 2:
hidden_irreps_out = str(
hidden_irreps[0]
) # Select only scalars for last layer
# if tensor_format == "symmetric_tucker":
# hidden_irreps_out = str(make_tp_irreps(o3.Irreps(hidden_irreps_out), correlation[i+1]))
else:
hidden_irreps_out = hidden_irreps
# if tensor_format == "symmetric_tucker":
# hidden_irreps_out = str(make_tp_irreps(o3.Irreps(hidden_irreps_out), correlation[i+1]))
inter = interaction_cls(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=hidden_irreps,
Expand All @@ -155,6 +174,7 @@ def __init__(
hidden_irreps=hidden_irreps_out,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
correlation=correlation[i + 1]
)
self.interactions.append(inter)
prod = EquivariantProductBasisBlock(
Expand Down
Loading

0 comments on commit 731b3e8

Please sign in to comment.