Skip to content

Commit

Permalink
fix case with multihead foundation model
Browse files Browse the repository at this point in the history
ilyes319 committed Nov 12, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent abd7e5e commit 1300ad1
Showing 7 changed files with 610 additions and 255 deletions.
29 changes: 23 additions & 6 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,6 @@
LRScheduler,
check_path_ase_read,
convert_to_json_format,
create_error_table,
dict_to_array,
extract_config_mace_model,
get_atomic_energies,
@@ -49,9 +48,11 @@
get_params_options,
get_swa,
print_git_commit,
remove_pt_head,
setup_wandb,
)
from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.tables_utils import create_error_table
from mace.tools.utils import AtomicNumberTable


@@ -115,10 +116,6 @@ def run(args: argparse.Namespace) -> None:
commit = print_git_commit()
model_foundation: Optional[torch.nn.Module] = None
if args.foundation_model is not None:
if args.multiheads_finetuning:
assert (
args.E0s != "average"
), "average atomic energies cannot be used for multiheads finetuning"
if args.foundation_model in ["small", "medium", "large"]:
logging.info(
f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint."
@@ -148,6 +145,27 @@ def run(args: argparse.Namespace) -> None:
f"Using foundation model {args.foundation_model} as initial checkpoint."
)
args.r_max = model_foundation.r_max.item()
if (
args.foundation_model not in ["small", "medium", "large"]
and args.pt_train_file is None
):
logging.warning(
"Using multiheads finetuning with a foundation model that is not a Materials Project model, need to provied a path to a pretraining file with --pt_train_file."
)
args.multiheads_finetuning = False
if args.multiheads_finetuning:
assert (
args.E0s != "average"
), "average atomic energies cannot be used for multiheads finetuning"
# check that the foundation model has a single head, if not, use the first head
if hasattr(model_foundation, "heads"):
if len(model_foundation.heads) > 1:
logging.warning(
"Mutlihead finetuning with models with more than one head is not supported, using the first head as foundation head."
)
model_foundation = remove_pt_head(
model_foundation, args.foundation_head
)
else:
args.multiheads_finetuning = False

@@ -587,7 +605,6 @@ def run(args: argparse.Namespace) -> None:
distributed_model = DDP(model, device_ids=[local_rank])
else:
distributed_model = None

tools.train(
model=model,
loss_fn=loss_fn,
7 changes: 7 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
@@ -360,6 +360,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=str2bool,
default=True,
)
parser.add_argument(
"--foundation_model_head",
help="Name of the head to use for fine-tuning",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--weight_pt_head",
help="Weight of the pretrained head in the loss function",
20 changes: 9 additions & 11 deletions mace/tools/finetuning_utils.py
Original file line number Diff line number Diff line change
@@ -73,10 +73,10 @@ def load_foundations_elements(
model.interactions[i].linear.weight = torch.nn.Parameter(
model_foundations.interactions[i].linear.weight.clone()
)
if (
model.interactions[i].__class__.__name__
in ["RealAgnosticResidualInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
):
if model.interactions[i].__class__.__name__ in [
"RealAgnosticResidualInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
]:
model.interactions[i].skip_tp.weight = torch.nn.Parameter(
model_foundations.interactions[i]
.skip_tp.weight.reshape(
@@ -101,19 +101,17 @@ def load_foundations_elements(
.clone()
/ (num_species_foundations / num_species) ** 0.5
)
if (
model.interactions[i].__class__.__name__
in ["RealAgnosticDensityInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
):
if model.interactions[i].__class__.__name__ in [
"RealAgnosticDensityInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
]:
# Assuming only 1 layer in density_fn
getattr(model.interactions[i].density_fn, "layer0").weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].density_fn,
"layer0",
)
.weight
.clone()
).weight.clone()
)
)
# Transferring products
1 change: 0 additions & 1 deletion mace/tools/model_script_utils.py
Original file line number Diff line number Diff line change
@@ -53,7 +53,6 @@ def configure_model(
model_config_foundation["atomic_inter_shift"] = (
_determine_atomic_inter_shift(args.mean, heads)
)

model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads)
args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"]
args.model = "FoundationMACE"
Loading

0 comments on commit 1300ad1

Please sign in to comment.