-
Notifications
You must be signed in to change notification settings - Fork 217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ep/qmml #666
Ep/qmml #666
Changes from all commits
fb188e1
384f778
09fb039
bb27705
c1cfd98
6c524c9
4266da0
b433573
4d4cc5c
f8494fd
2d23890
cd06b83
800c3b0
aff25bf
f4ca4ee
c465ce4
2156399
c86b404
1ebad7a
f99a432
3a399fa
64b5d2e
9517bd2
c503f6b
ff969cf
68dcf26
16ea5ca
ba6883c
eb70cb8
fcd5e2d
916ccdd
34c1fc1
77571cb
88c2064
9706526
024d27d
9eafe7b
cea9927
5097565
0e51fbc
5a57b74
1bf6c56
f8f294e
92a795e
edb1491
5289b03
23e4285
9f65452
4732a44
99911e1
be62272
74f9388
242c883
1e1f2bf
0f6fc8b
ca4acd7
fe4cdb1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
include src/schnetpack/train/ressources/partition_spline_for_robust_loss.npz | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,5 +57,8 @@ script-files = [ | |
"src/scripts/spkdeploy", | ||
] | ||
|
||
# Ensure package data such as resources are included | ||
package-data = { "schnetpack.train" = ["ressources/partition_spline_for_robust_loss.npz"] } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is needed to include the file when building the package |
||
|
||
[tool.setuptools.dynamic] | ||
version = {attr = "schnetpack.__version__"} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
defaults: | ||
- custom | ||
|
||
_target_: schnetpack.datasets.QCML | ||
|
||
datapath: ${run.data_dir}/qcml.db # data_dir is specified in train.yaml | ||
batch_size: 50 | ||
num_train: 0.90 | ||
num_val: 0.05 | ||
load_properties: [formation_energy,forces,charge,multiplicity] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don`t think it is possible to pass a list like this. If I remember correctly it should work like
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is possible to pass a list like this. |
||
version: 0.0.3 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# @package _global_ | ||
|
||
defaults: | ||
- override /data: qcml | ||
- override /model/representation: painn | ||
- override /model/representation/radial_basis: bernstein ### NEW ADDED FEATURE | ||
- override /task/scheduler: multistep ### NEW ADDED CONFIG | ||
|
||
run: | ||
experiment: qcml | ||
|
||
seed: 0 | ||
|
||
globals: | ||
cutoff: 10. | ||
lr: 1e-3 | ||
energy_key: formation_energy | ||
forces_key: forces | ||
total_charge_key: charge ### NEW ADDED FEATURE | ||
spin_key: multiplicity ### NEW ADDED FEATURE | ||
|
||
data: | ||
datapath: ??? | ||
load_properties: [formation_energy,forces,charge,multiplicity] | ||
batch_size: 50 | ||
num_train: 0.90 | ||
num_val: 0.05 | ||
num_workers: 4 | ||
num_val_workers: 4 | ||
distance_unit: Bohr | ||
property_units: | ||
energy: Hartree | ||
forces: Hartree/Bohr | ||
transforms: | ||
- _target_: schnetpack.transform.SubtractCenterOfMass | ||
- _target_: schnetpack.transform.RemoveOffsets ### NEW ADDED FEATURE | ||
property: ${globals.energy_key} | ||
remove_mean: True | ||
- _target_: schnetpack.transform.MatScipyNeighborList | ||
cutoff: ${globals.cutoff} | ||
- _target_: schnetpack.transform.CastTo32 | ||
|
||
model: | ||
representation: | ||
nuclear_embedding: | ||
_target_: schnetpack.nn.embedding.NuclearEmbedding | ||
max_z: 101 | ||
num_features: ${globals.representation_features} # same as n_atom_basis | ||
electronic_embeddings: ### NEW ADDED FEATURE | ||
- _target_: schnetpack.nn.embedding.ElectronicEmbedding | ||
property_key: ${globals.total_charge_key} | ||
num_features: ${model.representation.n_atom_basis} | ||
is_charged: true | ||
num_residual: 1 | ||
- _target_: schnetpack.nn.embedding.ElectronicEmbedding ### NEW ADDED FEATURE | ||
property_key: ${globals.spin_key} | ||
num_features: ${model.representation.n_atom_basis} | ||
is_charged: false | ||
num_residual: 1 | ||
output_modules: | ||
- _target_: schnetpack.atomistic.Atomwise | ||
output_key: ${globals.energy_key} | ||
n_in: ${model.representation.n_atom_basis} | ||
aggregation_mode: sum | ||
- _target_: schnetpack.atomistic.Forces | ||
energy_key: ${globals.energy_key} | ||
force_key: ${globals.forces_key} | ||
postprocessors: | ||
- _target_: schnetpack.transform.CastTo64 | ||
- _target_: schnetpack.transform.AddOffsets | ||
property: ${globals.energy_key} | ||
add_mean: True | ||
|
||
task: | ||
scheduler_args: | ||
milestones: [3,9,15,18,24,30,36] | ||
outputs: | ||
- _target_: schnetpack.task.ModelOutput | ||
name: ${globals.energy_key} | ||
loss_fn: | ||
_target_: schnetpack.train.AdaptiveLossFunction ### NEW ADDED FEATURE | ||
num_dims: 1 | ||
metrics: | ||
mae: | ||
_target_: torchmetrics.regression.MeanAbsoluteError | ||
rmse: | ||
_target_: torchmetrics.regression.MeanSquaredError | ||
squared: False | ||
loss_weight: 0.05 | ||
- _target_: schnetpack.task.ModelOutput | ||
name: ${globals.forces_key} | ||
loss_fn: | ||
_target_: schnetpack.train.AdaptiveLossFunction ### NEW ADDED FEATURE | ||
num_dims: 3 | ||
metrics: | ||
mae: | ||
_target_: torchmetrics.regression.MeanAbsoluteError | ||
rmse: | ||
_target_: torchmetrics.regression.MeanSquaredError | ||
squared: False | ||
loss_weight: 0.95 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
_target_: schnetpack.nn.radial.BernsteinRBF | ||
n_rbf: 32 | ||
cutoff: ${globals.cutoff} | ||
init_alpha: 0.95 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# @package task | ||
scheduler_cls: torch.optim.lr_scheduler.MultiStepLR | ||
scheduler_monitor: val_loss | ||
scheduler_args: | ||
milestones: ??? | ||
gamma: 0.5 | ||
last_epoch: -1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From this file values are loaded which are used to approximate the partition function to for the adaptive loss fn