Skip to content

Commit

Permalink
fix order of loss multihead
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Jun 10, 2024
1 parent bce718a commit 3f65c29
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def main() -> None:

if args.multiheads_finetuning:
logging.info("Using multiheads finetuning mode")
args.loss = "universal"
if heads is not None:
heads = list(dict.fromkeys(["pbe_mp"] + heads))
args.heads = heads
Expand Down Expand Up @@ -552,7 +553,6 @@ def main() -> None:
model_config["atomic_inter_shift"] = [args.mean] * len(heads)
model_config["atomic_inter_scale"] = [1.0] * len(heads)
args.model = "FoundationMACE"
args.loss = "universal"
model_config["heads"] = args.heads
logging.info("Model configuration extracted from foundation model")
logging.info("Using universal loss function for fine-tuning")
Expand Down

0 comments on commit 3f65c29

Please sign in to comment.