Modified unit tests to improve the coverage score
kenko911 committed Feb 12, 2024
1 parent 12038df commit 12bef9d
Showing 3 changed files with 148 additions and 28 deletions.
64 changes: 44 additions & 20 deletions src/matgl/layers/_embedding.py
@@ -10,7 +10,12 @@
import matgl
from matgl.layers._core import MLP
from matgl.utils.cutoff import cosine_cutoff
from matgl.utils.maths import new_radial_tensor, tensor_norm, vector_to_skewtensor, vector_to_symtensor
from matgl.utils.maths import (
new_radial_tensor,
tensor_norm,
vector_to_skewtensor,
vector_to_symtensor,
)


class EmbeddingBlock(nn.Module):
@@ -51,8 +56,18 @@ def __init__(
self.activation = activation
if ntypes_state and dim_state_embedding is not None:
self.layer_state_embedding = nn.Embedding(ntypes_state, dim_state_embedding) # type: ignore
elif dim_state_feats is not None:
self.layer_state_embedding = nn.Sequential(
nn.LazyLinear(dim_state_feats, bias=False, dtype=matgl.float_th),
activation,
)
if ntypes_node is not None:
self.layer_node_embedding = nn.Embedding(ntypes_node, dim_node_embedding)
else:
self.layer_node_embedding = nn.Sequential(
nn.LazyLinear(dim_node_embedding, bias=False, dtype=matgl.float_th),
activation,
)
if dim_edge_embedding is not None:
dim_edges = [degree_rbf, dim_edge_embedding]
self.layer_edge_embedding = MLP(dim_edges, activation=activation, activate_last=True)
@@ -73,8 +88,7 @@ def forward(self, node_attr, edge_attr, state_attr):
if self.ntypes_node is not None:
node_feat = self.layer_node_embedding(node_attr)
else:
node_embed = MLP([node_attr.shape[-1], self.dim_node_embedding], activation=self.activation)
node_feat = node_embed(node_attr.to(matgl.float_th))
node_feat = self.layer_node_embedding(node_attr.to(matgl.float_th))
if self.dim_edge_embedding is not None:
edge_feat = self.layer_edge_embedding(edge_attr.to(matgl.float_th))
else:
@@ -84,8 +98,7 @@ def forward(self, node_attr, edge_attr, state_attr):
state_feat = self.layer_state_embedding(state_attr)
elif self.dim_state_feats is not None:
state_attr = torch.unsqueeze(state_attr, 0)
state_embed = MLP([state_attr.shape[-1], self.dim_state_feats], activation=self.activation)
state_feat = state_embed(state_attr.to(matgl.float_th))
state_feat = self.layer_state_embedding(state_attr.to(matgl.float_th))
else:
state_feat = state_attr
else:
@@ -109,7 +122,7 @@ def __init__(
include_state: bool = False,
ntypes_state: int | None = None,
dim_state_feats: int | None = None,
dim_state_embedding: int | None = None,
dim_state_embedding: int = 0,
):
"""
Args:
@@ -140,26 +153,32 @@ def __init__(
self.linears_scalar.append(nn.Linear(2 * units, 3 * units, bias=True, dtype=dtype))
self.init_norm = nn.LayerNorm(units, dtype=dtype)
self.cutoff = cutoff
if ntypes_state and include_state is not None:
if ntypes_state is not None and dim_state_embedding > 0:
self.layer_state_embedding = nn.Embedding(ntypes_state, dim_state_embedding) # type: ignore
self.emb2 = nn.Linear(2 * units + dim_state_embedding, units, dtype=dtype) # type: ignore
elif dim_state_feats is not None:
self.layer_state_mlp = nn.Sequential(nn.LazyLinear(dim_state_feats, bias=False, dtype=dtype), activation)
self.emb2 = nn.Linear(2 * units + dim_state_feats, units, dtype=dtype)
else:
self.layer_state_embedding = None
self.emb2 = nn.Linear(2 * units, units, dtype=dtype) # type: ignore
self.emb2 = nn.Linear(2 * units, units, dtype=dtype)
self.emb3 = nn.Linear(degree_rbf, units)
self.reset_parameters()
self.dim_state_fests = dim_state_feats
self.dim_state_feats = dim_state_feats
self.include_state = include_state
self.ntypes_state = ntypes_state
self.dim_state_embedding = dim_state_embedding
self.reset_parameters()

def reset_parameters(self):
"""Reinitialize the parameters."""
self.distance_proj1.reset_parameters()
self.distance_proj2.reset_parameters()
self.distance_proj3.reset_parameters()
if self.layer_state_embedding is not None:
if self.dim_state_embedding > 0:
self.layer_state_embedding.reset_parameters()
if self.dim_state_feats is not None:
for layer in self.layer_state_mlp:
if hasattr(layer, "reset_parameters"):
layer.reset_parameters()
self.emb.reset_parameters()
self.emb2.reset_parameters()
self.emb3.reset_parameters()
@@ -248,17 +267,14 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None):
W2,
W3,
)
state_feat = None
if self.include_state is True:
if self.ntypes_state and self.dim_state_embedding is not None:
state_feat = self.layer_state_embedding(state_attr)
elif self.dim_state_feats is not None:
state_attr = torch.unsqueeze(state_attr, 0)
state_embed = MLP([state_attr.shape[-1], self.dim_state_feats], activation=self.activation)
state_feat = state_embed(state_attr.to(matgl.float_th))
else:
state_feat = state_attr
else:
state_feat = None
state_feat = self.layer_state_mlp(state_attr.to(matgl.float_th))

edge_feat = self.emb3(edge_attr)
with g.local_scope():
g.edata["Iij"] = Iij
@@ -282,7 +298,12 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None):
norm = self.act(linear_scalar(norm))
norm = norm.reshape(norm.shape[0], self.units, 3)
scalars, skew_metrices, traceless_tensors = new_radial_tensor(
scalars, skew_metrices, traceless_tensors, norm[..., 0], norm[..., 1], norm[..., 2]
scalars,
skew_metrices,
traceless_tensors,
norm[..., 0],
norm[..., 1],
norm[..., 2],
)
X = scalars + skew_metrices + traceless_tensors

@@ -351,7 +372,10 @@ def forward(
x_neighbors = self.embedding(z)
msg = W * x_neighbors.index_select(0, edge_index[1])
x_neighbors = torch.zeros(
node_feat.shape[0], node_feat.shape[1], dtype=node_feat.dtype, device=node_feat.device
node_feat.shape[0],
node_feat.shape[1],
dtype=node_feat.dtype,
device=node_feat.device,
).index_add(0, edge_index[0], msg)
x_neighbors = self.combine(torch.cat([node_feat, x_neighbors], dim=1))
return x_neighbors
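
The main refactor in _embedding.py above moves the state and node embedding MLPs out of forward() and registers them once in __init__ as nn.Sequential(nn.LazyLinear(...), activation), so their parameters are created a single time and tracked by the module rather than being rebuilt on every call. Below is a minimal standalone sketch of that pattern; the class name and dimensions are illustrative and are not matgl code.

    import torch
    from torch import nn

    class StateEmbeddingSketch(nn.Module):
        """Illustrative module: register a lazy state MLP once, reuse it in forward()."""

        def __init__(self, dim_state_feats: int = 16):
            super().__init__()
            # LazyLinear infers its input width on the first forward pass.
            self.layer_state_embedding = nn.Sequential(
                nn.LazyLinear(dim_state_feats, bias=False),
                nn.SiLU(),
            )

        def forward(self, state_attr: torch.Tensor) -> torch.Tensor:
            return self.layer_state_embedding(state_attr)

    block = StateEmbeddingSketch(dim_state_feats=16)
    state_feat = block(torch.tensor([[0.0, 0.0]]))  # -> shape (1, 16)
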
106 changes: 100 additions & 6 deletions tests/layers/test_core_and_embedding.py
@@ -41,7 +41,13 @@ def test_gated_equivairant_block(self, x):
vector_input = torch.randn(4, 3, 10)

output_scalar, output_vector = GatedEquivariantBlock(
n_sin=10, n_vin=10, n_sout=10, n_vout=10, n_hidden=10, activation=nn.SiLU(), sactivation=nn.SiLU()
n_sin=10,
n_vin=10,
n_sout=10,
n_vout=10,
n_hidden=10,
activation=nn.SiLU(),
sactivation=nn.SiLU(),
)([scaler_input, vector_input])

assert output_scalar.shape == (4, 10)
@@ -51,7 +57,47 @@ def test_build_gated_equivariant_mlp(selfself, x):
scaler_input = x
vector_input = torch.randn(4, 3, 10)
net = build_gated_equivariant_mlp( # type: ignore
n_in=10, n_out=1, n_hidden=10, n_layers=2, activation=nn.SiLU(), sactivation=nn.SiLU()
n_in=10,
n_out=1,
n_hidden=10,
n_layers=2,
activation=nn.SiLU(),
sactivation=nn.SiLU(),
)
output_scalar, output_vector = net([scaler_input, vector_input])

assert output_scalar.shape == (4, 1)
assert torch.squeeze(output_vector).shape == (4, 3)
# without n_hidden
net = build_gated_equivariant_mlp(n_in=10, n_out=1, n_layers=2, activation=nn.SiLU(), sactivation=nn.SiLU())
output_scalar, output_vector = net([scaler_input, vector_input])

assert output_scalar.shape == (4, 1)
assert torch.squeeze(output_vector).shape == (4, 3)
# with n_gating_hidden
net = build_gated_equivariant_mlp(
n_in=10,
n_out=1,
n_hidden=10,
n_layers=2,
n_gating_hidden=2,
activation=nn.SiLU(),
sactivation=nn.SiLU(),
)
output_scalar, output_vector = net([scaler_input, vector_input])

assert output_scalar.shape == (4, 1)
assert torch.squeeze(output_vector).shape == (4, 3)

# with sequence n_gating_hidden
net = build_gated_equivariant_mlp(
n_in=10,
n_out=1,
n_hidden=10,
n_layers=2,
n_gating_hidden=[10, 10],
activation=nn.SiLU(),
sactivation=nn.SiLU(),
)
output_scalar, output_vector = net([scaler_input, vector_input])

@@ -107,7 +153,11 @@ def test_embedding(self, graph_Mo):
assert [state_feat.size(dim=0), state_feat.size(dim=1)] == [1, 16]
# without any state feature
embed4 = EmbeddingBlock(
degree_rbf=9, dim_node_embedding=16, dim_edge_embedding=16, ntypes_node=2, activation=nn.SiLU()
degree_rbf=9,
dim_node_embedding=16,
dim_edge_embedding=16,
ntypes_node=2,
activation=nn.SiLU(),
)
node_feat, edge_feat, state_feat = embed4(
node_attr, edge_attr, torch.tensor([0.0, 0.0])
@@ -131,15 +181,55 @@ def test_tensor_embedding(self, graph_Mo):
s1, g1, state1 = graph_Mo
bond_expansion = BondExpansion(rbf_type="SphericalBessel", max_n=3, max_l=3, cutoff=4.0, smooth=True)
g1.edata["edge_attr"] = bond_expansion(g1.edata["bond_dist"])

# without state
tensor_embedding = TensorEmbedding(
units=64, degree_rbf=3, activation=nn.SiLU(), ntypes_node=1, cutoff=5.0, dtype=matgl.float_th
units=64,
degree_rbf=3,
activation=nn.SiLU(),
ntypes_node=1,
cutoff=5.0,
dtype=matgl.float_th,
)

X, edge_feat, state_feat = tensor_embedding(g1, state1)

assert [X.shape[0], X.shape[1], X.shape[2], X.shape[3]] == [2, 64, 3, 3]
assert [edge_feat.shape[0], edge_feat.shape[1]] == [52, 64]
# with state embedding
tensor_embedding = TensorEmbedding(
units=64,
degree_rbf=3,
activation=nn.SiLU(),
ntypes_node=1,
cutoff=5.0,
dtype=matgl.float_th,
ntypes_state=2,
include_state=True,
dim_state_embedding=16,
)

X, edge_feat, state_feat = tensor_embedding(g1, torch.tensor([1]))

assert [X.shape[0], X.shape[1], X.shape[2], X.shape[3]] == [2, 64, 3, 3]
assert [edge_feat.shape[0], edge_feat.shape[1]] == [52, 64]
assert [state_feat.shape[0], state_feat.shape[1]] == [1, 16]

# with state MLP
tensor_embedding = TensorEmbedding(
units=64,
degree_rbf=3,
activation=nn.SiLU(),
ntypes_node=1,
cutoff=5.0,
dtype=matgl.float_th,
include_state=True,
dim_state_feats=16,
)
X, edge_feat, state_feat = tensor_embedding(g1, torch.tensor(state1))

assert [X.shape[0], X.shape[1], X.shape[2], X.shape[3]] == [2, 64, 3, 3]
assert [edge_feat.shape[0], edge_feat.shape[1]] == [52, 64]
assert [state_feat.shape[0], state_feat.shape[1]] == [1, 16]

def test_neighbor_embedding(self, graph_Mo):
s1, g1, state1 = graph_Mo
@@ -151,7 +241,11 @@ def test_neighbor_embedding(self, graph_Mo):
g1.ndata["node_feat"] = torch.rand(2, 64)

x = embedding(
g1.ndata["node_type"], g1.ndata["node_feat"], g1.edges(), g1.edata["bond_dist"], g1.edata["edge_attr"]
g1.ndata["node_type"],
g1.ndata["node_feat"],
g1.edges(),
g1.edata["bond_dist"],
g1.edata["edge_attr"],
)

assert [x.shape[0], x.shape[1]] == [2, 64]
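
The new state-MLP tests above assert that a raw state vector comes back as a state_feat of shape (1, 16) when dim_state_feats=16: the state is unsqueezed to a batch of one and mapped by the lazily initialized MLP. A self-contained sketch of just that dimensional bookkeeping, outside of matgl (the variable names are illustrative):

    import torch
    from torch import nn

    dim_state_feats = 16
    # Lazy MLP analogous to layer_state_mlp above; the input width is inferred at the first call.
    state_mlp = nn.Sequential(nn.LazyLinear(dim_state_feats, bias=False), nn.SiLU())

    state_attr = torch.tensor([0.0, 0.0])            # raw 2-component state vector
    state_feat = state_mlp(state_attr.unsqueeze(0))  # unsqueeze to a batch of one
    assert state_feat.shape == (1, dim_state_feats)  # the (1, 16) shape asserted above
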
6 changes: 4 additions & 2 deletions tests/layers/test_readout.py
@@ -160,8 +160,9 @@ def test_global_pool(self, graph_MoS):
g1.edata["edge_feat"] = edge_feat
g_attr = torch.zeros((1, 16), dtype=matgl.float_th)
pool = GlobalPool(feat_size=16)
g_feat = pool(g1, node_feat, g_attr)
g_feat, attention_weights = pool(g1, node_feat, g_attr, True)
assert [g_feat.shape[0], g_feat.shape[1]] == [1, 16]
assert [attention_weights.shape[0]] == [2]

def test_attentive_fp(self, graph_MoS):
s, g1, state = graph_MoS
@@ -181,5 +182,6 @@ def test_attentive_fp(self, graph_MoS):
g1.ndata["node_feat"] = node_feat
g1.edata["edge_feat"] = edge_feat
pool = AttentiveFPReadout(feat_size=16, num_timesteps=2)
g_feat = pool(g1, node_feat)
g_feat, attention_weights = pool(g1, node_feat, True)
assert [g_feat.shape[0], g_feat.shape[1]] == [1, 16]
assert [attention_weights.shape[0], attention_weights.shape[1]] == [2, 2]
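
The readout tests now pass an extra boolean so that GlobalPool and AttentiveFPReadout also return their attention weights. The sketch below is a generic illustration of that calling convention only; the class and the softmax weighting scheme are invented for the example and are not matgl's pooling implementation.

    import torch
    from torch import nn

    class AttentionPoolSketch(nn.Module):
        """Softmax-attention pooling that can optionally return its weights."""

        def __init__(self, feat_size: int):
            super().__init__()
            self.score = nn.Linear(feat_size, 1)

        def forward(self, node_feat: torch.Tensor, get_attention: bool = False):
            weights = torch.softmax(self.score(node_feat), dim=0)    # (n_nodes, 1)
            pooled = (weights * node_feat).sum(dim=0, keepdim=True)  # (1, feat_size)
            return (pooled, weights.squeeze(-1)) if get_attention else pooled

    pool = AttentionPoolSketch(feat_size=16)
    g_feat, attention_weights = pool(torch.rand(2, 16), get_attention=True)
    assert g_feat.shape == (1, 16) and attention_weights.shape == (2,)
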
