diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 478d0a9b..866ffeb5 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -180,6 +180,7 @@ def main() -> None: if args.multiheads_finetuning: logging.info("Using multiheads finetuning mode") + args.loss = "universal" if heads is not None: heads = list(dict.fromkeys(["pbe_mp"] + heads)) args.heads = heads @@ -552,7 +553,6 @@ def main() -> None: model_config["atomic_inter_shift"] = [args.mean] * len(heads) model_config["atomic_inter_scale"] = [1.0] * len(heads) args.model = "FoundationMACE" - args.loss = "universal" model_config["heads"] = args.heads logging.info("Model configuration extracted from foundation model") logging.info("Using universal loss function for fine-tuning")