Skip to content

Commit

Permalink
add soft_normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Apr 7, 2023
1 parent 9fb8ed2 commit d988c08
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
1 change: 1 addition & 0 deletions configs/aspirin_small.gin
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ model.avg_num_neighbors = "average"
# 7: -1484.9814568572233,
# 8: -2041.9816003861047
# }
model.soft_normalization = 1000.0


loss.energy_weight = 1.0
Expand Down
19 changes: 19 additions & 0 deletions mace_jax/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
epsilon: Optional[float] = None,
correlation: int = 3, # Correlation order at each layer (~ node_features^correlation), default 3
gate: Callable = jax.nn.silu, # activation function
soft_normalization: Optional[float] = None,
symmetric_tensor_product_basis: bool = True,
off_diagonal: bool = False,
interaction_irreps: Union[str, e3nn.Irreps] = "o3_restricted", # or o3_full
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
self.symmetric_tensor_product_basis = symmetric_tensor_product_basis
self.off_diagonal = off_diagonal
self.max_ell = max_ell
self.soft_normalization = soft_normalization

# Embeddings
self.node_embedding = node_embedding(
Expand Down Expand Up @@ -154,6 +156,7 @@ def __call__(
readout_mlp_irreps=self.readout_mlp_irreps,
symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
off_diagonal=self.off_diagonal,
soft_normalization=self.soft_normalization,
name=f"layer_{i}",
)(
vectors,
Expand Down Expand Up @@ -189,6 +192,7 @@ def __init__(
correlation: int,
symmetric_tensor_product_basis: bool,
off_diagonal: bool,
soft_normalization: Optional[float],
# ReadoutBlock:
output_irreps: e3nn.Irreps,
readout_mlp_irreps: e3nn.Irreps,
Expand All @@ -210,6 +214,7 @@ def __init__(
self.readout_mlp_irreps = readout_mlp_irreps
self.symmetric_tensor_product_basis = symmetric_tensor_product_basis
self.off_diagonal = off_diagonal
self.soft_normalization = soft_normalization

def __call__(
self,
Expand Down Expand Up @@ -283,6 +288,20 @@ def __call__(
f"{self.name}: tensor power", node_feats, node_mask[:, None]
)

if self.soft_normalization is not None:

def phi(n):
n = n / self.soft_normalization
return 1.0 / (1.0 + n * e3nn.sus(n))

node_feats = e3nn.norm_activation(
node_feats, [phi] * len(node_feats.irreps)
)

node_feats = profile(
f"{self.name}: soft normalization", node_feats, node_mask[:, None]
)

if sc is not None:
node_feats = node_feats + sc # [n_nodes, feature * hidden_irreps]

Expand Down

0 comments on commit d988c08

Please sign in to comment.