From 33ca5105ae76ae3a23958c3125c224c5988e5c90 Mon Sep 17 00:00:00 2001 From: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com> Date: Thu, 7 Mar 2024 12:38:58 -0600 Subject: [PATCH] Small Refactoring for include_state (#239) * model version for Potential class is added * model version for Potential class is modified * Enable the smooth version of Spherical Bessel function in TensorNet * max_n, max_l for SphericalBessel radial basis functions are included in TensorNet class * adding united tests for improving the coverage score * little clean up in _so3.py and so3.py * remove unnecessary data storage in dgl graphs * update pymatgen version to fix the bug * refactor all include_states into include_state for consistency * change include_states into include_state in test_graph_conv.py --- src/matgl/layers/_graph_convolution.py | 24 ++++++++++++------------ src/matgl/models/_m3gnet.py | 4 ++-- tests/layers/test_graph_conv.py | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/matgl/layers/_graph_convolution.py b/src/matgl/layers/_graph_convolution.py index 7cac9548..accca1be 100644 --- a/src/matgl/layers/_graph_convolution.py +++ b/src/matgl/layers/_graph_convolution.py @@ -238,7 +238,7 @@ class M3GNetGraphConv(Module): def __init__( self, - include_states: bool, + include_state: bool, edge_update_func: Module, edge_weight_func: Module, node_update_func: Module, @@ -246,7 +246,7 @@ def __init__( state_update_func: Module | None, ): """Parameters: - include_states (bool): Whether including state + include_state (bool): Whether including state edge_update_func (Module): Update function for edges (Eq. 4) edge_weight_func (Module): Weight function for radial basis functions (Eq. 4) node_update_func (Module): Update function for nodes (Eq. 5) @@ -254,7 +254,7 @@ def __init__( state_update_func (Module): Update function for state feats (Eq. 6). """ super().__init__() - self.include_states = include_states + self.include_state = include_state self.edge_update_func = edge_update_func self.edge_weight_func = edge_weight_func self.node_update_func = node_update_func @@ -264,7 +264,7 @@ def __init__( @staticmethod def from_dims( degree, - include_states, + include_state, edge_dims: list[int], node_dims: list[int], state_dims: list[int] | None, @@ -274,7 +274,7 @@ def from_dims( Args: degree (int): max_n*max_l - include_states (bool): whether including state or not + include_state (bool): whether including state or not edge_dims (list): NN architecture for edge update function node_dims (list): NN architecture for node update function state_dims (list): NN architecture for state update function @@ -288,9 +288,9 @@ def from_dims( node_update_func = GatedMLP(in_feats=node_dims[0], dims=node_dims[1:]) node_weight_func = nn.Linear(in_features=degree, out_features=node_dims[-1], bias=False) - attr_update_func = MLP(state_dims, activation, activate_last=True) if include_states else None # type: ignore + attr_update_func = MLP(state_dims, activation, activate_last=True) if include_state else None # type: ignore return M3GNetGraphConv( - include_states, edge_update_func, edge_weight_func, node_update_func, node_weight_func, attr_update_func + include_state, edge_update_func, edge_weight_func, node_update_func, node_weight_func, attr_update_func ) def _edge_udf(self, edges: dgl.udf.EdgeBatch): @@ -305,11 +305,11 @@ def _edge_udf(self, edges: dgl.udf.EdgeBatch): vi = edges.src["v"] vj = edges.dst["v"] u = None - if self.include_states: + if self.include_state: u = edges.src["u"] eij = edges.data.pop("e") rbf = edges.data["rbf"] - inputs = torch.hstack([vi, vj, eij, u]) if self.include_states else torch.hstack([vi, vj, eij]) + inputs = torch.hstack([vi, vj, eij, u]) if self.include_state else torch.hstack([vi, vj, eij]) mij = {"mij": self.edge_update_func(inputs) * self.edge_weight_func(rbf)} return mij @@ -342,7 +342,7 @@ def node_update_(self, graph: dgl.DGLGraph, state_feat: Tensor) -> Tensor: dst_id = graph.edges()[1] vj = graph.ndata["v"][dst_id] rbf = graph.edata["rbf"] - if self.include_states: + if self.include_state: u = dgl.broadcast_edges(graph, state_feat) inputs = torch.hstack([vi, vj, eij, u]) else: @@ -390,14 +390,14 @@ def forward( with graph.local_scope(): graph.edata["e"] = edge_feat graph.ndata["v"] = node_feat - if self.include_states: + if self.include_state: graph.ndata["u"] = dgl.broadcast_nodes(graph, state_feat) edge_update = self.edge_update_(graph) graph.edata["e"] = edge_feat + edge_update node_update = self.node_update_(graph, state_feat) graph.ndata["v"] = node_feat + node_update - if self.include_states: + if self.include_state: state_feat = self.state_update_(graph, state_feat) return edge_feat + edge_update, node_feat + node_update, state_feat diff --git a/src/matgl/models/_m3gnet.py b/src/matgl/models/_m3gnet.py index 458718bd..1c7cd936 100644 --- a/src/matgl/models/_m3gnet.py +++ b/src/matgl/models/_m3gnet.py @@ -214,7 +214,7 @@ def __init__( self.units = units self.cutoff = cutoff self.threebody_cutoff = threebody_cutoff - self.include_states = include_state + self.include_state = include_state self.task_type = task_type self.is_intensive = is_intensive @@ -280,7 +280,7 @@ def forward( g.edata["edge_feat"] = edge_feat if self.is_intensive: field_vec = self.readout(g) - readout_vec = torch.hstack([field_vec, state_feat]) if self.include_states else field_vec # type: ignore + readout_vec = torch.hstack([field_vec, state_feat]) if self.include_state else field_vec # type: ignore fea_dict["readout"] = readout_vec output = self.final_layer(readout_vec) if self.task_type == "classification": diff --git a/tests/layers/test_graph_conv.py b/tests/layers/test_graph_conv.py index 9bd63b68..80107fe9 100644 --- a/tests/layers/test_graph_conv.py +++ b/tests/layers/test_graph_conv.py @@ -113,7 +113,7 @@ def test_m3gnet_graph_conv(self, graph_MoS): degree = max_n * max_l conv = M3GNetGraphConv.from_dims( degree=degree, - include_states=True, + include_state=True, edge_dims=[edge_in, *conv_hiddens, num_edge_feats], node_dims=[node_in, *conv_hiddens, num_node_feats], state_dims=[state_in, *conv_hiddens, num_state_feats],