
CHGNet-matgl implementation #242

Merged: 351 commits, May 6, 2024

Commits
697c443
ENH: fixing chgnet dset
lbluque Aug 13, 2023
62693b2
MAINT: create tensors in lg device
lbluque Aug 13, 2023
f2ed897
MAINT: use register buffer in Potential and LightningPotential
lbluque Aug 13, 2023
a54ff6e
MAIN: rename chgnet graph feats
lbluque Aug 13, 2023
b865bbb
FIX: clamp cos values to -1, 1 with eps
lbluque Aug 13, 2023
22d151f
Merge branch 'main' into chgnet
lbluque Aug 14, 2023
1de8ef6
ENH: start implementing chgnetdset
lbluque Aug 9, 2023
8300daa
Fix loading graphs
lbluque Aug 14, 2023
e48095e
use dgl path attrs in chgnet dataset
lbluque Aug 14, 2023
2a3cb4e
TST: add chgnetdataset test and fix errors
lbluque Aug 14, 2023
f4dad3b
Merge branch 'dataset' into chgnet
lbluque Aug 14, 2023
1e41da9
TST assert that unnormalized predictions are not the same
lbluque Aug 14, 2023
3e8698e
Merge branch 'main' into chgnet
lbluque Aug 15, 2023
1b4ade3
TST: clamp cos values to -1, 1 with eps in tests
lbluque Aug 16, 2023
a12015e
ENH: use torch.nan for None magmoms
lbluque Aug 16, 2023
bf47998
BUG: fix setting lg node data
lbluque Aug 18, 2023
9480c05
use no_grad in directed line graph
lbluque Aug 18, 2023
5437861
FIX: set lg data using num nodes
lbluque Aug 21, 2023
0305df3
TST: test up to 4 decimals
lbluque Aug 21, 2023
18c1eec
Merge branch 'main' of https://github.com/materialsvirtuallab/matgl i…
lbluque Aug 21, 2023
30d53dc
MAINT: update to renamed DEFAULT_ELEMENTS
lbluque Aug 21, 2023
6fb24d0
FIX: directed lg compatibility
lbluque Aug 22, 2023
b35e4c5
maint: update to new dataset interface
lbluque Aug 22, 2023
2415330
MAINT: update to new dataset interface
lbluque Aug 23, 2023
27a575b
TST: fix graph test
lbluque Aug 23, 2023
57e5681
MAINT: minor edit in directed line graph
lbluque Aug 23, 2023
1061aec
Merge branch 'main' into chgnet
lbluque Aug 23, 2023
374cbd0
Merge branch 'main' into chgnet
lbluque Aug 24, 2023
382b835
update to use dtype interface
lbluque Aug 24, 2023
16c7f29
Merge branch 'main' into chgnet
lbluque Aug 24, 2023
f2e14de
add tol to threebody cutoff
lbluque Aug 23, 2023
8f70a1c
add tol to threebody cutoff
lbluque Aug 23, 2023
507074f
FiX: remove tol and set pbc_offshift to float64
lbluque Aug 23, 2023
33a5a42
ENH: chunked chgnet dataset
lbluque Aug 24, 2023
688b952
remove state attr in has_cache
lbluque Aug 24, 2023
c4611cc
fix chunk_sizes
lbluque Aug 24, 2023
ee5a4e6
trange when loading indices
lbluque Aug 24, 2023
0b1c2a1
singular keys in collate
lbluque Aug 24, 2023
07ae295
hard code label keys
lbluque Aug 24, 2023
e1b3593
run pre-commit
lbluque Aug 24, 2023
5929716
change chgnet default elements
lbluque Aug 24, 2023
27f458e
FIX: create nan tensor for missing magmoms
lbluque Aug 24, 2023
a4a4b9d
add tol to threebody cutoff
lbluque Aug 23, 2023
f29c837
add tol to threebody cutoff
lbluque Aug 23, 2023
10bdea9
FiX: remove tol and set pbc_offshift to float64
lbluque Aug 23, 2023
0506078
ENH: chunked chgnet dataset
lbluque Aug 24, 2023
77cbcc6
remove state attr in has_cache
lbluque Aug 24, 2023
f566459
fix chunk_sizes
lbluque Aug 24, 2023
9e1e326
trange when loading indices
lbluque Aug 24, 2023
2cfbca8
singular keys in collate
lbluque Aug 24, 2023
63a895f
hard code label keys
lbluque Aug 24, 2023
fa50dab
run pre-commit
lbluque Aug 24, 2023
13e8319
change chgnet default elements
lbluque Aug 24, 2023
971135e
Merge remote-tracking branch 'origin/chgnet_train' into chgnet_train
lbluque Aug 24, 2023
9c1c741
FIX: nan tensor shape
lbluque Aug 24, 2023
644ea71
Merge branch 'chgnet' into chgnet_train
lbluque Aug 24, 2023
b02e307
FIX: allow skipping nan tensors
lbluque Aug 24, 2023
4fa29fb
add xavier normal and update chunked dataset
lbluque Aug 26, 2023
2546df5
fix getitem
lbluque Aug 26, 2023
013c8d6
fix getitem
lbluque Aug 26, 2023
bed14de
fix getitem
lbluque Aug 26, 2023
a97a65e
fix getitem
lbluque Aug 26, 2023
882da3a
fix getitem
lbluque Aug 26, 2023
2f31e23
fix getitem
lbluque Aug 26, 2023
aa765b2
huber loss
lbluque Aug 31, 2023
7df9c0d
MAINT: use torch instead of numpy
lbluque Sep 1, 2023
f913e0a
MAINT: keep onehot matrix as attribute
lbluque Sep 1, 2023
414c924
Merge branch 'main' into chgnet_train
lbluque Sep 1, 2023
e2facd0
MAINT: remove unnecessary statements
lbluque Sep 1, 2023
3794af0
MAINT: remove unnecessary statements
lbluque Sep 1, 2023
b84f24c
MAINT: onehot as buffer
lbluque Sep 1, 2023
2fb9855
MAINT: property offset as buffer
lbluque Sep 1, 2023
92ad929
MAINT: onehot as buffer
lbluque Sep 1, 2023
01f148b
MAINT: property offset as buffer
lbluque Sep 1, 2023
4932972
change order in init
lbluque Sep 1, 2023
5c3ae23
TST update tests
lbluque Sep 1, 2023
7c64f44
ENH use lstsq to avoid constructing full normal eqs
lbluque Sep 1, 2023
bdf89b0
change order in init
lbluque Sep 1, 2023
9053ac9
TST update tests
lbluque Sep 1, 2023
24664fb
ENH use lstsq to avoid constructing full normal eqs
lbluque Sep 1, 2023
c524162
remove numpy import
lbluque Sep 1, 2023
729207d
remove print
lbluque Sep 1, 2023
16c262c
Merge branch 'main' into main
shyuep Sep 1, 2023
0740986
STY: fix lint
lbluque Sep 1, 2023
d330699
FIX: backwards compat with pre-trained models
lbluque Sep 1, 2023
a5032f7
ENH: raise load_model error from baseexception
lbluque Sep 1, 2023
22a2e4e
TST: fix atomref tests
lbluque Sep 1, 2023
2320580
STY: ruff
lbluque Sep 1, 2023
cd58dcf
Merge branch 'main' into chgnet_train
lbluque Sep 1, 2023
dbd8fdf
FIX: use tuple in isinstance for 3.9 compat
lbluque Sep 1, 2023
9a99106
Merge branch 'main' into chgnet_train
lbluque Sep 1, 2023
adae431
remove numpy import
lbluque Sep 1, 2023
6265749
STY: ruff
lbluque Sep 1, 2023
42baa27
Merge branch 'main' into chgnet
lbluque Sep 1, 2023
fd76452
Merge branch 'chgnet' into chgnet_train
lbluque Sep 1, 2023
f86a51d
remove numpy import
lbluque Sep 1, 2023
2a35a85
STY: ruff
lbluque Sep 1, 2023
e74c425
Merge branch 'main' into chgnet
lbluque Sep 8, 2023
55ad683
Merge branch 'chgnet' into chgnet_train
lbluque Sep 8, 2023
938f61b
remove assert in compat (fails for some batched graphs)
lbluque Sep 8, 2023
2807de6
ENH: messy graphnorm mess
lbluque Sep 11, 2023
5553fe0
FIX: fix allow missing labels
lbluque Sep 11, 2023
a91a789
use lg num_nodes() directly
lbluque Sep 12, 2023
eb369e1
use lg num_nodes() directly
lbluque Sep 12, 2023
987b94a
do not assert
lbluque Sep 12, 2023
f19e913
FIX: fix ensuring line graph for bonds right at cutoff
lbluque Sep 13, 2023
80b53a5
remove numpy import
lbluque Sep 1, 2023
e874067
STY: ruff
lbluque Sep 1, 2023
f43c8ee
Merge remote-tracking branch 'origin/main'
lbluque Sep 14, 2023
1fa36e3
Merge branch 'main' into chgnet
lbluque Sep 14, 2023
d851d84
Remove wheel and release.
Sep 11, 2023
c228311
Bump pymatgen from 2023.9.2 to 2023.9.10 (#162)
dependabot[bot] Sep 11, 2023
e6dcd03
Add united test for trainer.test and description in the example (#165)
kenko911 Sep 13, 2023
ebbbb33
Merge branch 'chgnet' into chgnet_train
lbluque Sep 14, 2023
0225624
ENH: allow skipping label keys
lbluque Sep 12, 2023
239488d
use tuple
lbluque Sep 15, 2023
5f03aa9
ENH: allow skipping label keys
lbluque Sep 12, 2023
fcbf03b
use tuple
lbluque Sep 15, 2023
fc9efbe
use skip labels in chunked dataset
lbluque Sep 15, 2023
1deadbf
add empty axis to magmoms
lbluque Sep 15, 2023
3c15846
add empty axis to magmoms
lbluque Sep 15, 2023
0e4e2a4
ENH: graph norm implementation
lbluque Sep 16, 2023
0132791
TST: add graph_norm test
lbluque Sep 16, 2023
48061a1
remove adding extra axis to magmoms
lbluque Sep 16, 2023
1695d60
remove adding extra axis to magmoms
lbluque Sep 16, 2023
3192fd7
add skip label keys to chunked dataset
lbluque Sep 18, 2023
3d0a4a3
fix chunked dset
lbluque Sep 18, 2023
d62bdf4
add OOM dataset
lbluque Sep 19, 2023
ed9375d
len w state_attr
lbluque Sep 19, 2023
efa77e5
int idx
lbluque Sep 19, 2023
31af30d
increase compatibility tol
lbluque Sep 19, 2023
c1b6309
lintings
lbluque Sep 19, 2023
ac87ee5
STY: fix some linting errors
lbluque Sep 19, 2023
fb7c31d
Merge branch 'chgnet' into chgnet_train
lbluque Sep 19, 2023
6193abc
STY: fix mypy errors
lbluque Sep 19, 2023
9b37f34
remove numpy import
lbluque Sep 1, 2023
3f775a3
STY: ruff
lbluque Sep 1, 2023
429c89b
remove numpy import
lbluque Sep 1, 2023
53bcc3a
STY: ruff
lbluque Sep 1, 2023
95ae563
TYP: use Sequence instead of list
lbluque Sep 19, 2023
4f2d4da
lint
lbluque Sep 19, 2023
386e3b5
MAINT: use sequential in MLP
lbluque Sep 19, 2023
31d14dd
ENH: norm gated MLP
lbluque Sep 20, 2023
f46f782
Merge branch 'main' of https://github.com/materialsvirtuallab/matgl
lbluque Sep 21, 2023
e462a8a
MAINT: use sequential in MLP
lbluque Sep 19, 2023
fa62f07
store linear layers and activation separately in MLP
lbluque Sep 21, 2023
90e650a
use MLP in gated MLP
lbluque Sep 21, 2023
5796ecf
remove unnecessary Sequential
lbluque Sep 21, 2023
64dfa09
correct magmom training index!
lbluque Sep 25, 2023
9107a44
revert magmom index bc it was correct!
lbluque Sep 25, 2023
70909b0
Merge branch 'mlp' into graph_norm
lbluque Sep 26, 2023
fa0c2a4
ENH: graphnorm in mlp and gmlp
lbluque Sep 27, 2023
7f80d40
remove numpy import
lbluque Sep 1, 2023
714b8db
STY: ruff
lbluque Sep 1, 2023
1d90eb1
remove numpy import
lbluque Sep 1, 2023
e9ffc14
STY: ruff
lbluque Sep 1, 2023
c0c188a
FIX: remove repeated bond expansion
lbluque Sep 27, 2023
0fa1413
Merge branch 'main' into chgnet
lbluque Sep 27, 2023
485ef21
Merge branch 'chgnet' into chgnet_train
lbluque Sep 27, 2023
0c5860a
hack to load new state dicts in PL checkpoints
lbluque Sep 27, 2023
f61d570
allow site_wise loss options
lbluque Sep 28, 2023
80c7b16
only set grad enabled in forward
lbluque Sep 28, 2023
4d880f0
adapt core to allow normalization of different layers
lbluque Sep 29, 2023
5b91b97
Merge branch 'chgnet_train' of https://github.com/lbluque/matgl into …
lbluque Sep 29, 2023
2505fd7
Merge branch 'chgnet' of https://github.com/lbluque/matgl into chgnet
lbluque Sep 29, 2023
137391d
remove some TODOS
lbluque Sep 29, 2023
8089cf7
Merge branch 'graph_norm' into chgnet_norm
lbluque Sep 29, 2023
c846690
allow normalization in chgnet
lbluque Sep 29, 2023
ab08f0d
always normalize last
lbluque Sep 29, 2023
ba6d7bf
always normalize last
lbluque Sep 29, 2023
2bd649c
fix normalization inputs
lbluque Sep 29, 2023
8855677
fix mlp forward
lbluque Sep 29, 2023
c4a2a1a
fix mlp forward
lbluque Sep 29, 2023
a309278
messy norm
lbluque Sep 29, 2023
13613c9
allow norm kwargs and allow batching by edges or nodes in graphnorm
lbluque Sep 29, 2023
4b9ca21
test graphnorm
lbluque Sep 29, 2023
3afd275
Merge branch 'graph_norm' into chgnet_norm
lbluque Sep 29, 2023
f53e570
graph norm in chgnet
lbluque Sep 29, 2023
1a5d64e
allow layernorm in chgnet
lbluque Sep 29, 2023
0dc540e
allow layernorm in chgnet
lbluque Sep 29, 2023
f1663cb
rename args
lbluque Sep 29, 2023
d59411e
rename args
lbluque Sep 29, 2023
221035f
Merge branch 'main' of https://github.com/materialsvirtuallab/matgl i…
lbluque Sep 29, 2023
0dfc26d
Merge branch 'chgnet' into chgnet_norm
lbluque Sep 29, 2023
b2bffd0
Merge branch 'chgnet' into chgnet_train
lbluque Sep 29, 2023
71a339c
Merge branch 'chgnet_train' into chgnet_norm_train
lbluque Sep 29, 2023
c1426dd
fix mypy errors
lbluque Sep 29, 2023
862adfe
add tolerance in lg compatibility
lbluque Oct 2, 2023
23505ea
add tolerance in lg compatibility
lbluque Oct 2, 2023
3fc132a
raise runtime error for incompatible graph
lbluque Oct 2, 2023
7ec7234
raise runtime error for incompatible graph
lbluque Oct 2, 2023
b5f5abf
create tensors on same device in norm
lbluque Oct 2, 2023
f61379d
create tensors on same device in norm
lbluque Oct 2, 2023
7d0f5a6
Merge branch 'main' of https://github.com/materialsvirtuallab/matgl i…
lbluque Oct 2, 2023
932d1b7
Merge branch 'chgnet' into chgnet_norm
lbluque Oct 2, 2023
a5b0adc
Merge branch 'chgnet_norm' into chgnet_norm_train
lbluque Oct 2, 2023
0adcb01
Merge branch 'chgnet' into chgnet_train
lbluque Oct 2, 2023
14fcf86
Merge branch 'main' into chgnet
lbluque Oct 3, 2023
cded08f
update chgnet to use new line graph interface
lbluque Oct 3, 2023
567557f
update chgnet paper link
lbluque Oct 3, 2023
ae9156c
Merge branch 'chgnet' into chgnet_norm_train
lbluque Oct 3, 2023
f2e739f
Merge branch 'chgnet' into chgnet_train
lbluque Oct 3, 2023
635b153
Merge branch 'chgnet' into chgnet_norm
lbluque Oct 3, 2023
25bda3f
update line graph in dataset
lbluque Oct 3, 2023
13766ba
Merge branch 'chgnet' into chgnet_norm
lbluque Oct 3, 2023
a93dcb8
Merge branch 'chgnet' into chgnet_train
lbluque Oct 3, 2023
01e095f
Merge branch 'chgnet_train' into chgnet_norm_train
lbluque Oct 3, 2023
c34aa9e
no bias in output of conv layers
lbluque Oct 5, 2023
8c8dc38
Merge branch 'chgnet' into chgnet_train
lbluque Oct 5, 2023
7993f4f
Merge branch 'chgnet' into chgnet_norm
lbluque Oct 5, 2023
8f930e7
Merge branch 'chgnet_norm' into chgnet_norm_train
lbluque Oct 5, 2023
2146dc2
some docstrings
bowen-bd Oct 18, 2023
8545e3a
moved mlp_out from InteractionBlock to ConvFunctions and added non-li…
bowen-bd Oct 19, 2023
368efbb
Merge branch 'chgnet' into chgnet_train
bowen-bd Oct 21, 2023
aa0dbe5
fix typo
lbluque Oct 23, 2023
92c2e0b
Merge branch 'chgnet' into chgnet_train
lbluque Oct 23, 2023
0211e04
Merge branch 'chgnet' into chgnet_norm
lbluque Oct 23, 2023
536ad8b
moved out_layer to linear
bowen-bd Oct 31, 2023
51b3a74
solved bug
bowen-bd Oct 31, 2023
982e5e9
Merge branch 'chgnet' into chgnet_norm
lbluque Nov 7, 2023
5a3f200
Merge branch 'chgnet_norm' into chgnet_norm_train
lbluque Nov 7, 2023
c15d5c2
solved bug
bowen-bd Nov 10, 2023
2595c34
Merge branch 'chgnet_norm' into chgnet_norm_train
bowen-bd Nov 10, 2023
ec8e528
removed normalization from bondgraph layer
bowen-bd Nov 14, 2023
6404299
Merge branch 'chgnet_norm' into chgnet_norm_train
bowen-bd Nov 14, 2023
bfefe2b
uploaded pretrained model and modified ASE interface
bowen-bd Dec 13, 2023
f926e02
fix linting
bowen-bd Dec 13, 2023
452f056
updated stress calculation to match mvl matgl
bowen-bd Jan 12, 2024
ad49c71
fixed chgnet dataset by adding lattice
bowen-bd Jan 12, 2024
8257d8d
hot fix
bowen-bd Jan 12, 2024
26fbd9a
add frac_coords to pre-processed graphs
bowen-bd Jan 12, 2024
00c7a6d
hot fix
bowen-bd Jan 13, 2024
d6fce6d
solved bug
bowen-bd Jan 18, 2024
27ca01c
remove ignore model
bowen-bd Jan 25, 2024
a6f6088
add 11M model weights
bowen-bd Feb 14, 2024
7a5e3e5
renamed pretrained weights
bowen-bd Mar 1, 2024
8ab82c4
Adding CHGNet-matgl implementation
bowen-bd Mar 20, 2024
902b0e5
corrected texts and comments
bowen-bd Mar 20, 2024
d7a9a31
fix more texts
bowen-bd Mar 20, 2024
e52de99
more texts fixes
bowen-bd Mar 20, 2024
bcb71fd
refactor CHGNet path in test
bowen-bd Mar 20, 2024
340b9da
fixed linting
bowen-bd Mar 20, 2024
4a875e0
fixed texts
bowen-bd Mar 20, 2024
2c92d97
Merge remote-tracking branch 'bowen/main' into chgnet_final
lbluque Mar 29, 2024
392cc03
Merge pull request #1 from lbluque/chgnet_final
bowen-bd Mar 29, 2024
4985006
remove unused CHGNetDataset
bowen-bd Mar 29, 2024
a95c104
Merge branch 'main' into main
shyuep May 2, 2024
e035d1c
restructure matgl modules for CHGNet implementations
kenko911 May 5, 2024
cc80907
fix ruff
kenko911 May 6, 2024
49c5193
update model versioning for Potential class
kenko911 May 6, 2024
Files changed
33 changes: 33 additions & 0 deletions pretrained_models/CHGNet-MPtrj-2023.12.1-PES-2.7M/README.md
@@ -0,0 +1,33 @@
# Description

This model is a CHGNet universal potential trained from the Materials Project trajectory (MPtrj) dataset
that contains over 1.5 million structures with 89 elements.
This Matgl implementation has slight modification from original pytorch implementation by adding directed edge updates.
Review comment (Contributor): Replace "pytorch" with "PyTorch" to adhere to the correct capitalization of the framework's name.

- This Matgl implementation has slight modification from original pytorch implementation by adding directed edge updates.
+ This Matgl implementation has slight modification from original PyTorch implementation by adding directed edge updates.

# Training dataset

MPtrj-2022.9: Materials Project trajectory dataset that contains GGA and GGA+U static and relaxation calculations.
- Train-Val-Test splitting with mp-id: 0.9 - 0.5 - 0.5
Review comment (Contributor): For numerical ranges, consider using an en dash (–) instead of a hyphen (-) for improved readability. For example, "0.9–0.5–0.5" instead of "0.9 - 0.5 - 0.5".

- Train-Val-Test splitting with mp-id: 0.9 - 0.5 - 0.5
+ Train-Val-Test splitting with mp-id: 0.9–0.5–0.5
- Train set size: 1419861
- Validation set size: 79719
- Test set size: 79182

# Performance metrics
## Training and validation errors

| Partition  | Energy (meV/atom) | Force (meV/Å) | Stress (GPa) | Magmom (μB) |
| ---------- | ----------------- | ------------- | ------------ | ----------- |
| Train      | 26.45             | 49            | 0.173        | 0.036       |
| Validation | 30.31             | 70            | 0.297        | 0.037       |
| Test       | 30.80             | 66            | 0.296        | 0.038       |


# References

```txt
Deng, B. et al. CHGNet as a pretrained universal neural network potential for charge-informed atomistic modelling.
Nat. Mach. Intell. 1–11 (2023) doi:10.1038/s42256-023-00716-3.
```

#### Date: 2023.12.1
#### Author: Bowen Deng
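
As a quick-start illustration (not part of the diff): a minimal sketch of loading this pretrained potential and running a single-point calculation through the ASE interface added in this PR. The model name passed to `matgl.load_model` is assumed to match the `pretrained_models/` directory name above.

```python
from ase.build import bulk

import matgl
from matgl.ext.ase import PESCalculator

# Assumed model name: taken from the pretrained_models/ directory name above.
potential = matgl.load_model("CHGNet-MPtrj-2023.12.1-PES-2.7M")

# Single-point energy and forces for a bcc Fe cell.
atoms = bulk("Fe", "bcc", a=2.87)
atoms.calc = PESCalculator(potential)

print("energy (eV):", atoms.get_potential_energy())
print("forces (eV/Å):", atoms.get_forces())
```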
66 changes: 66 additions & 0 deletions pretrained_models/CHGNet-MPtrj-2023.12.1-PES-2.7M/model.json
@@ -0,0 +1,66 @@
{
"@class": "Potential",
"@module": "matgl.apps.pes",
"@model_version": 1,
"metadata": null,
"kwargs": {
"model": {
"@class": "CHGNet",
"@module": "matgl.models._chgnet",
"@model_version": 1,
"init_args": {
"element_types": null,
"dim_state_feats": null,
"non_linear_bond_embedding": false,
"non_linear_angle_embedding": false,
"cutoff": 6.0,
"threebody_cutoff": 3.0,
"cutoff_exponent": 5,
"max_f": 32,
"learn_basis": false,
"num_blocks": 5,
"shared_bond_weights": "both",
"final_mlp_type": "mlp",
"final_hidden_dims": [
128,
128,
128
],
"final_dropout": 0.0,
"pooling_operation": "sum",
"readout_field": "atom_feat",
"activation_type": "swish",
"is_intensive": false,
"num_targets": 1,
"num_site_targets": 1,
"task_type": "regression",
"angle_update_hidden_dims": [],
"atom_conv_hidden_dims": [
128
],
"bond_conv_hidden_dims": [
128
],
"bond_update_hidden_dims": [
128
],
"conv_dropout": 0.0,
"dim_angle_embedding": 128,
"dim_atom_embedding": 128,
"dim_bond_embedding": 128,
"dim_state_embedding": null,
"layer_bond_weights": null,
"max_n": 63,
"normalization": "layer",
"normalize_hidden": false
}
},
"data_mean": "tensor(0.)",
"data_std": "tensor(1.)",
"element_refs": "tensor([ -3.4524, -0.2535, -3.1356, -3.5818, -7.5282, -8.2669, -7.7537,\n -8.3183, -5.6419, -0.0301, -1.9928, -1.5805, -4.3933, -6.2148,\n -6.3137, -5.6612, -3.6236, -0.0632, -1.7023, -3.7368, -6.8803,\n -9.4099, -9.5156, -9.5164, -9.0957, -7.9901, -6.4274, -5.5935,\n -3.3122, -0.8411, -3.2124, -4.8460, -4.6307, -4.8599, -3.1515,\n 0.8167, -1.5970, -3.4679, -7.7371, -9.5454, -10.5613, -9.9911,\n -6.7590, -8.2281, -7.0194, -5.0765, -1.8264, -0.3508, -2.5801,\n -3.9090, -4.0772, -3.8814, -2.4899, 3.0526, -2.2544, -3.9207,\n -7.2230, -7.4592, -6.5778, -6.7207, -5.1486, -6.7854, -11.7604,\n -16.2078, -6.4504, -6.4809, -6.3964, -6.3873, -6.3777, -2.8369,\n -6.4433, -10.6572, -12.3228, -11.8578, -10.5498, -9.1838, -8.1969,\n -6.0177, -2.7653, 0.6777, -1.6524, -3.1770, -3.3326, 0.0000,\n 0.0000, 0.0000, 0.0000, 0.0000, -4.2516, -9.0753, -10.3286,\n -12.5580, -12.7219, -14.3059])",
"calc_forces": true,
"calc_stresses": true,
"calc_hessian": false,
"calc_site_wise": true
}
}
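
The serialized spec above is plain JSON, so the hyperparameters can be inspected without instantiating the model; a small sketch (the file path is hypothetical and relative to a checkout of the repository):

```python
import json

# Hypothetical path, relative to a checkout of the repository.
path = "pretrained_models/CHGNet-MPtrj-2023.12.1-PES-2.7M/model.json"

with open(path) as f:
    spec = json.load(f)

# Drill into the nested CHGNet constructor arguments.
init_args = spec["kwargs"]["model"]["init_args"]
print(init_args["cutoff"], init_args["threebody_cutoff"])  # 6.0 3.0
print(init_args["num_blocks"], init_args["dim_atom_embedding"])  # 5 128
```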
32 changes: 32 additions & 0 deletions pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/README.md
@@ -0,0 +1,32 @@
# Description

This model is a CHGNet universal potential trained from the Materials Project trajectory (MPtrj) dataset
that contains over 1.5 million structures with 89 elements.
This Matgl implementation has slight modification from original pytorch implementation by adding directed edge updates.
Review comment (Contributor): Replace "pytorch" with "PyTorch" to adhere to the correct capitalization of the framework's name.

- This Matgl implementation has slight modification from original pytorch implementation by adding directed edge updates.
+ This Matgl implementation has slight modification from original PyTorch implementation by adding directed edge updates.

# Training dataset

MPtrj-2022.9: Materials Project trajectory dataset that contains GGA and GGA+U static and relaxation calculations.
- Train-Val-Test splitting with mp-id: 0.95 - 0.5
Review comment (Contributor): For numerical ranges, consider using an en dash (–) instead of a hyphen (-) for improved readability. For example, "0.95–0.5" instead of "0.95 - 0.5".

- Train-Val-Test splitting with mp-id: 0.95 - 0.5
+ Train-Val-Test splitting with mp-id: 0.95–0.5
- Train set size: 1499043
- Validation set size: 79719
- Test set size: 0

# Performance metrics
## Training and validation errors

| Partition  | Energy (meV/atom) | Force (meV/Å) | Stress (GPa) | Magmom (μB) |
| ---------- | ----------------- | ------------- | ------------ | ----------- |
| Train      | 25.6              | 47.6          | 0.177        | 0.017       |
| Validation | 27.7              | 62.5          | 0.288        | 0.017       |


# References

```txt
Deng, B. et al. CHGNet as a pretrained universal neural network potential for charge-informed atomistic modelling.
Nat. Mach. Intell. 1–11 (2023) doi:10.1038/s42256-023-00716-3.
```

#### Date: 2024.2.13
#### Author: Bowen Deng
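
As a usage illustration (again not part of the diff), the larger model can drive a geometry relaxation through ASE; a sketch assuming the weights are published under the directory name above:

```python
from ase.build import bulk
from ase.optimize import FIRE

import matgl
from matgl.ext.ase import PESCalculator

# Assumed model name: taken from the pretrained_models/ directory name above.
potential = matgl.load_model("CHGNet-MPtrj-2024.2.13-PES-11M")

# Start from a perturbed diamond-Si cell and relax the atomic positions.
atoms = bulk("Si", "diamond", a=5.5)
atoms.rattle(stdev=0.05, seed=0)
atoms.calc = PESCalculator(potential)

FIRE(atoms).run(fmax=0.05)
print("relaxed energy (eV/cell):", atoms.get_potential_energy())
```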
67 changes: 67 additions & 0 deletions pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/model.json
@@ -0,0 +1,67 @@
{
"@class": "Potential",
"@module": "matgl.apps.pes",
"@model_version": 2,
"metadata": null,
"kwargs": {
"model": {
"@class": "CHGNet",
"@module": "matgl.models._chgnet",
"@model_version": 1,
"init_args": {
"element_types": null,
"dim_state_feats": null,
"non_linear_bond_embedding": false,
"non_linear_angle_embedding": false,
"cutoff": 6.0,
"threebody_cutoff": 3.0,
"cutoff_exponent": 5,
"max_f": 32,
"learn_basis": false,
"num_blocks": 5,
"shared_bond_weights": "both",
"final_mlp_type": "mlp",
"final_hidden_dims": [
256,
256,
256
],
"final_dropout": 0.0,
"pooling_operation": "sum",
"readout_field": "atom_feat",
"activation_type": "swish",
"is_intensive": false,
"num_targets": 1,
"num_site_targets": 1,
"task_type": "regression",
"angle_update_hidden_dims": [],
"atom_conv_hidden_dims": [
256
],
"bond_conv_hidden_dims": [
256
],
"bond_update_hidden_dims": [
256
],
"conv_dropout": 0.0,
"dim_angle_embedding": 256,
"dim_atom_embedding": 256,
"dim_bond_embedding": 256,
"dim_state_embedding": null,
"layer_bond_weights": null,
"max_n": 63,
"normalization": "layer",
"normalize_hidden": false
}
},
"data_mean": "tensor(0.)",
"data_std": "tensor(1.)",
"element_refs": "tensor([ -3.4524, -0.2535, -3.1356, -3.5818, -7.5282, -8.2669, -7.7537,\n -8.3183, -5.6419, -0.0301, -1.9928, -1.5805, -4.3933, -6.2148,\n -6.3137, -5.6612, -3.6236, -0.0632, -1.7023, -3.7368, -6.8803,\n -9.4099, -9.5156, -9.5164, -9.0957, -7.9901, -6.4274, -5.5935,\n -3.3122, -0.8411, -3.2124, -4.8460, -4.6307, -4.8599, -3.1515,\n 0.8167, -1.5970, -3.4679, -7.7371, -9.5454, -10.5613, -9.9911,\n -6.7590, -8.2281, -7.0194, -5.0765, -1.8264, -0.3508, -2.5801,\n -3.9090, -4.0772, -3.8814, -2.4899, 3.0526, -2.2544, -3.9207,\n -7.2230, -7.4592, -6.5778, -6.7207, -5.1486, -6.7854, -11.7604,\n -16.2078, -6.4504, -6.4809, -6.3964, -6.3873, -6.3777, -2.8369,\n -6.4433, -10.6572, -12.3228, -11.8578, -10.5498, -9.1838, -8.1969,\n -6.0177, -2.7653, 0.6777, -1.6524, -3.1770, -3.3326, 0.0000,\n 0.0000, 0.0000, 0.0000, 0.0000, -4.2516, -9.0753, -10.3286,\n -12.5580, -12.7219, -14.3059])",
"calc_forces": true,
"calc_stresses": true,
"calc_hessian": false,
"calc_site_wise": true,
"debug_mode": false
}
}
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -15,6 +15,8 @@ authors = [
{ name = "Ji Qi", email = "[email protected]" },
{ name = "Santiago Miret", email = "[email protected]" },
{ name = "Eliott Liu", email = "[email protected]" },
{ name = "Bowen Deng", email = "[email protected]" },
{ name = "Luis Barroso-Luque", email = "[email protected]" },
{ name = "Shyue Ping Ong", email = "[email protected]" },
]
description = "MatGL is a framework for graph deep learning for materials science."
2 changes: 1 addition & 1 deletion src/matgl/apps/pes.py
@@ -95,7 +95,7 @@ def forward(
         st = lat.new_zeros([g.batch_size, 3, 3])
         if self.calc_stresses:
             st.requires_grad_(True)
-        lattice = lat @ (torch.eye(3).to(st.device) + st)
+        lattice = lat @ (torch.eye(3, device=lat.device) + st)
         g.edata["lattice"] = torch.repeat_interleave(lattice, g.batch_num_edges(), dim=0)
         g.edata["pbc_offshift"] = (g.edata["pbc_offset"].unsqueeze(dim=-1) * g.edata["lattice"]).sum(dim=1)
         g.ndata["pos"] = (
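
Aside: the one-line change above allocates the identity matrix directly on the lattice tensor's device instead of creating it on CPU and copying it over; a minimal sketch of the two patterns (assuming a CUDA device is available):

```python
import torch

if torch.cuda.is_available():
    lat = torch.rand(2, 3, 3, device="cuda")

    # Old pattern: eye is built on CPU, then transferred to the GPU.
    eye_copied = torch.eye(3).to(lat.device)

    # New pattern: eye is allocated on the GPU in one step, avoiding the
    # intermediate CPU tensor and the host-to-device copy.
    eye_direct = torch.eye(3, device=lat.device)

    assert torch.equal(eye_copied, eye_direct)
```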
19 changes: 11 additions & 8 deletions src/matgl/ext/ase.py
@@ -122,7 +122,7 @@ def get_graph(self, atoms: Atoms) -> tuple[dgl.DGLGraph, torch.Tensor, list]:
 class PESCalculator(Calculator):
     """Potential calculator for ASE."""

-    implemented_properties = ("energy", "free_energy", "forces", "stress", "hessian")
+    implemented_properties = ("energy", "free_energy", "forces", "stress", "hessian", "magmoms")

     def __init__(
         self,
@@ -145,6 +145,7 @@
         self.potential = potential
         self.compute_stress = potential.calc_stresses
         self.compute_hessian = potential.calc_hessian
+        self.compute_magmoms = potential.calc_site_wise
         self.stress_weight = stress_weight
         self.state_attr = state_attr
         self.element_types = potential.model.element_types  # type: ignore
@@ -172,18 +173,20 @@ def calculate(
         graph, lattice, state_attr_default = Atoms2Graph(self.element_types, self.cutoff).get_graph(atoms)
         # type: ignore
         if self.state_attr is not None:
-            energies, forces, stresses, hessians = self.potential(graph, lattice, self.state_attr)
+            calc_result = self.potential(graph, lattice, self.state_attr)
         else:
-            energies, forces, stresses, hessians = self.potential(graph, lattice, state_attr_default)
+            calc_result = self.potential(graph, lattice, state_attr_default)
         self.results.update(
-            energy=energies.detach().cpu().numpy().item(),
-            free_energy=energies.detach().cpu().numpy().item(),
-            forces=forces.detach().cpu().numpy(),
+            energy=calc_result[0].detach().cpu().numpy().item(),
+            free_energy=calc_result[0].detach().cpu().numpy(),
+            forces=calc_result[1].detach().cpu().numpy(),
         )
         if self.compute_stress:
-            self.results.update(stress=stresses.detach().cpu().numpy() * self.stress_weight)
+            self.results.update(stress=calc_result[2].detach().cpu().numpy() * self.stress_weight)
         if self.compute_hessian:
-            self.results.update(hessian=hessians.detach().cpu().numpy())
+            self.results.update(hessian=calc_result[3].detach().cpu().numpy())
+        if self.compute_magmoms:
+            self.results.update(magmoms=calc_result[4].detach().cpu().numpy())
Review comment (Contributor): The update to the results dictionary to include magmoms when self.compute_magmoms is true is correctly implemented. It's important to ensure that the shape and data type of calc_result[4] match the expectations for magnetic moments in ASE. Consider adding a check or conversion to ensure the data type and shape are correct before updating the results:

         if self.compute_magmoms:
-            self.results.update(magmoms=calc_result[4].detach().cpu().numpy())
+            magmoms = calc_result[4].detach().cpu().numpy()
+            # Ensure magmoms is in the expected format for ASE.
+            self.results.update(magmoms=magmoms)

# for backward compatibility
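
Because the calculator now advertises "magmoms" and populates it when the potential predicts site-wise properties, ASE's standard accessor picks the values up directly; a short sketch, reusing the assumed model name from the 2.7M README above:

```python
from ase.build import bulk

import matgl
from matgl.ext.ase import PESCalculator

# Assumed model name, as in the earlier sketch.
potential = matgl.load_model("CHGNet-MPtrj-2023.12.1-PES-2.7M")

atoms = bulk("Fe", "bcc", a=2.87) * (2, 2, 2)
atoms.calc = PESCalculator(potential)

# calc_site_wise is true for this potential, so compute_magmoms is enabled
# and the "magmoms" entry of calc.results is filled on the first calculation.
print(atoms.get_magnetic_moments())  # one predicted moment per site
```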
7 changes: 4 additions & 3 deletions src/matgl/graph/compute.py
@@ -78,7 +78,8 @@ def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float, directed: bool =
     Args:
         g: DGL graph
         threebody_cutoff (float): cutoff for three-body interactions
-        directed (bool): Whether to create a directed line graph, or an m3gnet 3body line graph (default: False, m3gnet)
+        directed (bool): Whether to create a directed line graph, or an M3gnet 3body line graph
+            Default = False (M3Gnet)

     Returns:
         l_g: DGL graph containing three body information from graph
@@ -226,10 +227,10 @@ def _create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) ->
         threebody_cutoff: cutoff for three-body interactions

     Returns:
-        line_graph: DGL graph line graph of pruned graph to three body cutoff
+        line_graph: DGL line graph of pruned graph to three body cutoff
     """
     with torch.no_grad():
-        pg = prune_edges_by_features(graph, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff)
+        pg = prune_edges_by_features(graph, feat_name="bond_dist", condition=lambda x: torch.gt(x, threebody_cutoff))
         src_indices, dst_indices = pg.edges()
         images = pg.edata["pbc_offset"]
         all_indices = torch.arange(pg.number_of_nodes(), device=graph.device).unsqueeze(dim=0)
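
The pruning call above removes bonds beyond the three-body cutoff before the line graph is built, so angle nodes only form between short bonds. A standalone sketch of the same idea in plain DGL (prune_edges_by_features itself is matgl internals; only its usage above is taken from the diff):

```python
import dgl
import torch

# Toy graph: 4 atoms, 10 directed bonds with random lengths in [0, 6) Å.
g = dgl.rand_graph(4, 10)
g.edata["bond_dist"] = torch.rand(10) * 6.0

threebody_cutoff = 3.0

# Same condition as the diff: edges with bond_dist > cutoff are pruned.
drop = torch.nonzero(torch.gt(g.edata["bond_dist"], threebody_cutoff)).squeeze(1)
pg = dgl.remove_edges(g, drop)

# The line graph of the pruned graph only contains angles between bonds
# that are both within the three-body cutoff.
lg = dgl.line_graph(pg, shared=False)
print(lg.num_nodes(), lg.num_edges())
```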