
Ensure the state attr from molecular graph is consistent with matgl.float_th and include linear layer in TensorNet to match the original implementations #244

Merged 25 commits on Mar 29, 2024. Changes from all commits.

Commits (25)
b41dc07
model version for Potential class is added
kenko911 Nov 17, 2023
04634b1
model version for Potential class is modified
kenko911 Nov 17, 2023
54b925b
Merge branch 'materialsvirtuallab:main' into main
kenko911 Nov 26, 2023
fdd348f
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Nov 26, 2023
82efd10
Merge branch 'materialsvirtuallab:main' into main
kenko911 Dec 13, 2023
ada5eff
Merge branch 'materialsvirtuallab:main' into main
kenko911 Feb 5, 2024
cc24b45
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Feb 14, 2024
f59436e
Merge branch 'materialsvirtuallab:main' into main
kenko911 Feb 14, 2024
bea738f
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Feb 14, 2024
d4fe6f2
Merge branch 'materialsvirtuallab:main' into main
kenko911 Feb 16, 2024
b0f55c4
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Feb 16, 2024
4cd717f
Enable the smooth version of Spherical Bessel function in TensorNet
kenko911 Feb 16, 2024
ca3f952
max_n, max_l for SphericalBessel radial basis functions are included …
kenko911 Feb 16, 2024
0b31ac5
adding unit tests for improving the coverage score
kenko911 Feb 17, 2024
5ada1fc
Merge branch 'materialsvirtuallab:main' into main
kenko911 Feb 18, 2024
95a0f33
little clean up in _so3.py and so3.py
kenko911 Feb 19, 2024
0e77e3a
Merge branch 'materialsvirtuallab:main' into main
kenko911 Feb 20, 2024
2c5c23c
Merge branch 'materialsvirtuallab:main' into main
kenko911 Mar 1, 2024
2893736
remove unnecessary data storage in dgl graphs
kenko911 Mar 1, 2024
f89e42e
update pymatgen version to fix the bug
kenko911 Mar 3, 2024
37319e5
refactor all include_states into include_state for consistency
kenko911 Mar 7, 2024
391d3bf
change include_states into include_state in test_graph_conv.py
kenko911 Mar 7, 2024
904da58
Merge branch 'materialsvirtuallab:main' into main
kenko911 Mar 7, 2024
53415de
Merge branch 'materialsvirtuallab:main' into main
kenko911 Mar 29, 2024
065bfc2
Ensure the state attr from molecule graph is consistent with matgl.fl…
kenko911 Mar 29, 2024
src/matgl/graph/data.py (2 changes: 1 addition, 1 deletion)

@@ -189,7 +189,7 @@ def process(self):
         if self.graph_labels is not None:
             state_attrs = torch.tensor(self.graph_labels).long()
         else:
-            state_attrs = torch.tensor(np.array(state_attrs))
+            state_attrs = torch.tensor(np.array(state_attrs), dtype=matgl.float_th)

         if self.clear_processed:
             del self.structures
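The dtype change above can be illustrated in isolation. `torch.tensor` inherits the dtype of a NumPy array, and NumPy defaults to float64, so without an explicit `dtype` the state attributes would not match matgl's working precision. A minimal sketch, assuming `matgl.float_th` defaults to `torch.float32` (here a stand-in variable is used so the snippet does not require matgl itself):

```python
import numpy as np
import torch

# Stand-in for matgl.float_th; assumed to default to torch.float32.
float_th = torch.float32

# Hypothetical per-graph state attributes collected during processing.
state_attrs = [0.0, 0.0]

# Without an explicit dtype, NumPy's float64 leaks into the tensor:
implicit = torch.tensor(np.array(state_attrs))
print(implicit.dtype)  # torch.float64

# The fix pins the tensor to the library's working precision:
explicit = torch.tensor(np.array(state_attrs), dtype=float_th)
print(explicit.dtype)  # torch.float32
```

This keeps state features in the same precision as the rest of the model's parameters and node/edge features.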
src/matgl/models/_tensornet.py (6 changes: 4 additions, 2 deletions)

@@ -180,8 +180,9 @@ def __init__(
         )

         self.out_norm = nn.LayerNorm(3 * units, dtype=dtype)
+        self.linear = nn.Linear(3 * units, units, dtype=dtype)
         if is_intensive:
-            input_feats = 3 * units if field == "node_feat" else units
+            input_feats = units
             if readout_type == "set2set":
                 self.readout = Set2SetReadOut(
                     in_feats=input_feats, n_iters=niters_set2set, n_layers=nlayers_set2set, field=field

@@ -203,7 +204,7 @@ def __init__(
             if task_type == "classification":
                 raise ValueError("Classification task cannot be extensive.")
             self.final_layer = WeightedReadOut(
-                in_feats=3 * units,
+                in_feats=units,
                 dims=[units, units],
                 num_targets=ntargets,  # type: ignore
             )

@@ -247,6 +248,7 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, **kwargs):

         x = torch.cat((tensor_norm(scalars), tensor_norm(skew_metrices), tensor_norm(traceless_tensors)), dim=-1)
         x = self.out_norm(x)
+        x = self.linear(x)

         g.ndata["node_feat"] = x
         if self.is_intensive:
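The TensorNet change adds a linear projection after the layer norm, so the concatenated invariants from the three tensor components (scalars, skew-symmetric parts, traceless tensors) are mapped from 3 * units back down to units before any readout, which is why the readout `in_feats` values shrink from `3 * units` to `units` in the same diff. A minimal sketch of that projection step with plain PyTorch (sizes are hypothetical, and random features stand in for the real tensor invariants):

```python
import torch
from torch import nn

units, num_nodes = 64, 10  # hypothetical model width and node count

# Stand-in for the concatenation of the three normed tensor components,
# which yields 3 * units features per node.
x = torch.randn(num_nodes, 3 * units)

out_norm = nn.LayerNorm(3 * units)
linear = nn.Linear(3 * units, units)  # the projection added in this PR

x = out_norm(x)
x = linear(x)   # node features are now `units`-dimensional
print(x.shape)  # torch.Size([10, 64])
```

With this projection in place, every downstream readout layer sees a `units`-dimensional node feature regardless of the field, matching the original TensorNet implementation.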