Skip to content

Commit

Permalink
Small Refactoring for include_state (#239)
Browse files Browse the repository at this point in the history
* model version for Potential class is added

* model version for Potential class is modified

* Enable the smooth version of Spherical Bessel function in TensorNet

* max_n, max_l for SphericalBessel radial basis functions are included in TensorNet class

* adding unit tests to improve the coverage score

* little clean up in _so3.py and so3.py

* remove unnecessary data storage in dgl graphs

* update pymatgen version to fix the bug

* refactor all include_states into include_state for consistency

* change include_states into include_state in test_graph_conv.py
  • Loading branch information
kenko911 authored Mar 7, 2024
1 parent 08a0338 commit 33ca510
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
24 changes: 12 additions & 12 deletions src/matgl/layers/_graph_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,23 +238,23 @@ class M3GNetGraphConv(Module):

def __init__(
self,
include_states: bool,
include_state: bool,
edge_update_func: Module,
edge_weight_func: Module,
node_update_func: Module,
node_weight_func: Module,
state_update_func: Module | None,
):
"""Parameters:
include_states (bool): Whether including state
include_state (bool): Whether including state
edge_update_func (Module): Update function for edges (Eq. 4)
edge_weight_func (Module): Weight function for radial basis functions (Eq. 4)
node_update_func (Module): Update function for nodes (Eq. 5)
node_weight_func (Module): Weight function for radial basis functions (Eq. 5)
state_update_func (Module): Update function for state feats (Eq. 6).
"""
super().__init__()
self.include_states = include_states
self.include_state = include_state
self.edge_update_func = edge_update_func
self.edge_weight_func = edge_weight_func
self.node_update_func = node_update_func
Expand All @@ -264,7 +264,7 @@ def __init__(
@staticmethod
def from_dims(
degree,
include_states,
include_state,
edge_dims: list[int],
node_dims: list[int],
state_dims: list[int] | None,
Expand All @@ -274,7 +274,7 @@ def from_dims(
Args:
degree (int): max_n*max_l
include_states (bool): whether including state or not
include_state (bool): whether including state or not
edge_dims (list): NN architecture for edge update function
node_dims (list): NN architecture for node update function
state_dims (list): NN architecture for state update function
Expand All @@ -288,9 +288,9 @@ def from_dims(

node_update_func = GatedMLP(in_feats=node_dims[0], dims=node_dims[1:])
node_weight_func = nn.Linear(in_features=degree, out_features=node_dims[-1], bias=False)
attr_update_func = MLP(state_dims, activation, activate_last=True) if include_states else None # type: ignore
attr_update_func = MLP(state_dims, activation, activate_last=True) if include_state else None # type: ignore
return M3GNetGraphConv(
include_states, edge_update_func, edge_weight_func, node_update_func, node_weight_func, attr_update_func
include_state, edge_update_func, edge_weight_func, node_update_func, node_weight_func, attr_update_func
)

def _edge_udf(self, edges: dgl.udf.EdgeBatch):
Expand All @@ -305,11 +305,11 @@ def _edge_udf(self, edges: dgl.udf.EdgeBatch):
vi = edges.src["v"]
vj = edges.dst["v"]
u = None
if self.include_states:
if self.include_state:
u = edges.src["u"]
eij = edges.data.pop("e")
rbf = edges.data["rbf"]
inputs = torch.hstack([vi, vj, eij, u]) if self.include_states else torch.hstack([vi, vj, eij])
inputs = torch.hstack([vi, vj, eij, u]) if self.include_state else torch.hstack([vi, vj, eij])
mij = {"mij": self.edge_update_func(inputs) * self.edge_weight_func(rbf)}
return mij

Expand Down Expand Up @@ -342,7 +342,7 @@ def node_update_(self, graph: dgl.DGLGraph, state_feat: Tensor) -> Tensor:
dst_id = graph.edges()[1]
vj = graph.ndata["v"][dst_id]
rbf = graph.edata["rbf"]
if self.include_states:
if self.include_state:
u = dgl.broadcast_edges(graph, state_feat)
inputs = torch.hstack([vi, vj, eij, u])
else:
Expand Down Expand Up @@ -390,14 +390,14 @@ def forward(
with graph.local_scope():
graph.edata["e"] = edge_feat
graph.ndata["v"] = node_feat
if self.include_states:
if self.include_state:
graph.ndata["u"] = dgl.broadcast_nodes(graph, state_feat)

edge_update = self.edge_update_(graph)
graph.edata["e"] = edge_feat + edge_update
node_update = self.node_update_(graph, state_feat)
graph.ndata["v"] = node_feat + node_update
if self.include_states:
if self.include_state:
state_feat = self.state_update_(graph, state_feat)

return edge_feat + edge_update, node_feat + node_update, state_feat
Expand Down
4 changes: 2 additions & 2 deletions src/matgl/models/_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def __init__(
self.units = units
self.cutoff = cutoff
self.threebody_cutoff = threebody_cutoff
self.include_states = include_state
self.include_state = include_state
self.task_type = task_type
self.is_intensive = is_intensive

Expand Down Expand Up @@ -280,7 +280,7 @@ def forward(
g.edata["edge_feat"] = edge_feat
if self.is_intensive:
field_vec = self.readout(g)
readout_vec = torch.hstack([field_vec, state_feat]) if self.include_states else field_vec # type: ignore
readout_vec = torch.hstack([field_vec, state_feat]) if self.include_state else field_vec # type: ignore
fea_dict["readout"] = readout_vec
output = self.final_layer(readout_vec)
if self.task_type == "classification":
Expand Down
2 changes: 1 addition & 1 deletion tests/layers/test_graph_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_m3gnet_graph_conv(self, graph_MoS):
degree = max_n * max_l
conv = M3GNetGraphConv.from_dims(
degree=degree,
include_states=True,
include_state=True,
edge_dims=[edge_in, *conv_hiddens, num_edge_feats],
node_dims=[node_in, *conv_hiddens, num_node_feats],
state_dims=[state_in, *conv_hiddens, num_state_feats],
Expand Down

0 comments on commit 33ca510

Please sign in to comment.