Merge pull request #616 from ACEsuit/develop
fix swa bigger than epoch
ilyes319 authored Oct 2, 2024
2 parents 056a4c0 + 6b7d9c9 commit 96aa932
Showing 9 changed files with 600 additions and 109 deletions.
6 changes: 3 additions & 3 deletions mace/calculators/mace.py
@@ -50,9 +50,9 @@ class MACECalculator(Calculator):

     def __init__(
         self,
-        model_paths: Union[list, str] | None = None,
-        device: str | None = None,
-        models: Union[list[torch.nn.Module], torch.nn.Module] | None = None,
+        model_paths: Union[list, str, None] = None,
+        models: Union[list[torch.nn.Module], torch.nn.Module, None] = None,
+        device: str = "cpu",
         energy_units_to_eV: float = 1.0,
         length_units_to_A: float = 1.0,
         default_dtype="",
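For reference, a minimal sketch of constructing the calculator under the new signature; the model path below is hypothetical:

from mace.calculators import MACECalculator

calc = MACECalculator(
    model_paths="my_mace_model.model",  # hypothetical path to a trained model
    device="cpu",                       # now defaults to "cpu" instead of None
    default_dtype="float64",
)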
13 changes: 12 additions & 1 deletion mace/cli/create_lammps_model.py
@@ -20,6 +20,13 @@ def parse_args():
help="Head of the model to be converted to LAMMPS",
default=None,
)
parser.add_argument(
"--dtype",
type=str,
nargs="?",
help="Data type of the model to be converted to LAMMPS",
default="float64",
)
return parser.parse_args()


@@ -58,7 +65,11 @@ def main():
         model_path,
         map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
     )
-    model = model.double().to("cpu")
+    if args.dtype == "float64":
+        model = model.double().to("cpu")
+    elif args.dtype == "float32":
+        print("Converting model to float32, this may cause loss of precision.")
+        model = model.float().to("cpu")
 
     if args.head is None:
         head = select_head(model)
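A self-contained sketch of the new branch, using a stand-in module; note that any other value of args.dtype (including the None case above) silently leaves the model's dtype unchanged:

import torch

model = torch.nn.Linear(4, 4)  # stand-in for the loaded MACE model
dtype = "float32"              # stand-in for args.dtype

if dtype == "float64":
    model = model.double().to("cpu")
elif dtype == "float32":
    print("Converting model to float32, this may cause loss of precision.")
    model = model.float().to("cpu")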
17 changes: 16 additions & 1 deletion mace/cli/preprocess_data.py
@@ -66,7 +66,22 @@ def pool_compute_stats(inputs: List):
     pool.join()
 
     results = [r.get() for r in tqdm.tqdm(re)]
-    return np.average(results, axis=0)
+
+    if not results:
+        raise ValueError(
+            "No results were computed. Check if the input files exist and are readable."
+        )
+
+    # Separate avg_num_neighbors, mean, and std
+    avg_num_neighbors = np.mean([r[0] for r in results])
+    means = np.array([r[1] for r in results])
+    stds = np.array([r[2] for r in results])
+
+    # Compute averages
+    mean = np.mean(means, axis=0).item()
+    std = np.mean(stds, axis=0).item()
+
+    return avg_num_neighbors, mean, std
 
 
 def split_array(a: np.ndarray, max_size: int):
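The old np.average(results, axis=0) collapsed each worker's (avg_num_neighbors, mean, std) tuple into one averaged array instead of three separate statistics. A sketch of the new unpacking, with hypothetical per-process results:

import numpy as np

# hypothetical (avg_num_neighbors, mean, std) tuples from two worker processes
results = [(20.0, 1.5, 0.30), (22.0, 1.7, 0.20)]

avg_num_neighbors = np.mean([r[0] for r in results])              # 21.0
mean = np.mean(np.array([r[1] for r in results]), axis=0).item()  # 1.6
std = np.mean(np.array([r[2] for r in results]), axis=0).item()   # 0.25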
9 changes: 5 additions & 4 deletions mace/cli/run_train.py
@@ -245,7 +245,6 @@ def run(args: argparse.Namespace) -> None:
         args.loss = "universal"
         if (
             args.foundation_model in ["small", "medium", "large"]
-            or "mp" in args.foundation_model
             or args.pt_train_file is None
         ):
             logging.info(
@@ -344,6 +343,7 @@ def run(args: argparse.Namespace) -> None:
     atomic_energies_dict = {}
     for head_config in head_configs:
         if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0:
+            assert head_config.E0s is not None, "Atomic energies must be provided"
             if check_path_ase_read(head_config.train_file) and head_config.E0s.lower() != "foundation":
                 atomic_energies_dict[head_config.head_name] = get_atomic_energies(
                     head_config.E0s, head_config.collections.train, head_config.z_table
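The assert matters because the next line immediately calls head_config.E0s.lower(): without it, a missing E0s surfaced as an opaque AttributeError rather than a clear message. A hypothetical sketch of the guard firing:

E0s = None  # hypothetical: user supplied no atomic energies

try:
    assert E0s is not None, "Atomic energies must be provided"
    E0s.lower()  # previously the first code to touch E0s, raising AttributeError
except AssertionError as e:
    print(e)  # -> Atomic energies must be provided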
@@ -403,7 +403,10 @@ def run(args: argparse.Namespace) -> None:
         # )
         atomic_energies = dict_to_array(atomic_energies_dict, heads)
         for head_config in head_configs:
-            logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}")
+            try:
+                logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}")
+            except KeyError as e:
+                raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e


valid_sets = {head: [] for head in heads}
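A standalone sketch of the failure the try/except now reports, with hypothetical energies: the z_table contains oxygen (z=8) but the head's E0s do not:

atomic_energies_dict = {"default": {1: -13.6}}  # hypothetical E0s, missing z=8
zs = [1, 8]                                     # hypothetical z_table contents

try:
    line = ", ".join(f"{z}: {atomic_energies_dict['default'][z]}" for z in zs)
except KeyError as e:
    print(f"Atomic number {e} not found in atomic_energies_dict for head default, "
          "add E0s for this atomic number")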
@@ -627,9 +630,7 @@ def run(args: argparse.Namespace) -> None:
         stop_first_test = True
         for head_config in head_configs:
             if check_path_ase_read(head_config.train_file):
-                print(head_config.test_file)
                 for name, subset in head_config.collections.tests:
-                    print(name)
                     test_sets[name] = [
                         data.AtomicData.from_config(
                             config, z_table=z_table, cutoff=args.r_max, heads=heads
10 changes: 6 additions & 4 deletions mace/modules/utils.py
@@ -378,6 +378,7 @@ def compute_statistics(
     forces_list = []
     num_neighbors = []
     head_list = []
+    head_batch = []
 
     for batch in data_loader:
         head = batch.head
@@ -391,21 +392,22 @@
         ) # {[n_graphs], }
         forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], }
         head_list.append(head) # {[n_graphs], }
-
+        head_batch.append(head[batch.batch])
         _, receivers = batch.edge_index
         _, counts = torch.unique(receivers, return_counts=True)
         num_neighbors.append(counts)
 
     atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs]
     forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], }
     head = torch.cat(head_list, dim=0) # [total_n_graphs]
+    head_batch = torch.cat(head_batch, dim=0) # [total_n_atoms]
 
     # mean = to_numpy(torch.mean(atom_energies)).item()
     mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1))
     # do the mean for each head
     # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item()
     rms = to_numpy(
-        torch.sqrt(scatter_mean(src=torch.square(forces), index=head, dim=0))
+        torch.sqrt(
+            scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1)
+        )
     )
 
     avg_num_neighbors = torch.mean(
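The underlying bug: head has one entry per graph while forces has one row per atom, so scattering the squared forces by head misaligned the per-head force statistics; head[batch.batch] broadcasts the head index to atoms. A self-contained sketch with hypothetical sizes, emulating scatter_mean via index_add_:

import torch

forces = torch.randn(5, 3)                  # per-atom forces for 5 atoms
head_batch = torch.tensor([0, 0, 1, 1, 1])  # per-atom head index, i.e. head[batch.batch]
n_heads = 2

sq = torch.square(forces)                                  # [n_atoms, 3]
sums = torch.zeros(n_heads, 3).index_add_(0, head_batch, sq)
counts = torch.zeros(n_heads).index_add_(0, head_batch, torch.ones(5))
rms = torch.sqrt((sums / counts.unsqueeze(-1)).mean(-1))   # one force RMS per head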
3 changes: 2 additions & 1 deletion mace/tools/scripts_utils.py
@@ -449,10 +449,11 @@ def get_swa(
     if args.start_swa is None:
         args.start_swa = max(1, args.max_num_epochs // 4 * 3)
     else:
-        if args.start_swa > args.max_num_epochs:
+        if args.start_swa >= args.max_num_epochs:
             logging.warning(
                 f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}"
             )
+            swas[-1] = False
     if args.loss == "forces_only":
         raise ValueError("Can not select Stage Two with forces only loss.")
     if args.loss == "virials":
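In effect: a Stage Two (SWA) start at or beyond the final epoch can never trigger, so it is now disabled explicitly instead of being skipped silently. A sketch with hypothetical numbers:

import logging

max_num_epochs = 100
start_swa = 100  # hypothetical: equal to max_num_epochs, so Stage Two would never run
swas = [True]    # whether Stage Two is enabled

if start_swa is None:
    start_swa = max(1, max_num_epochs // 4 * 3)  # default: final quarter of training
elif start_swa >= max_num_epochs:
    logging.warning(
        f"Start Stage Two must be less than max_num_epochs, got {start_swa} >= {max_num_epochs}"
    )
    swas[-1] = False  # disable Stage Two rather than letting it silently never start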
(The diffs for the remaining 3 changed files are not shown here.)
