Commit 93335c2
fix head bug
birdyLinch committed Oct 13, 2024
1 parent 988afa7
Showing 36 changed files with 1,486 additions and 41 deletions.
8 changes: 7 additions & 1 deletion mace/cli/plot_neighbor_distribution.py
@@ -94,7 +94,10 @@ def main() -> None:

# Setup
tools.set_seeds(args.seed)
tools.setup_logger(level=args.log_level, tag=log_tag, directory=args.log_dir, rank=rank)

reduced_log_tag = "reduced_save_log"

tools.setup_logger(level=args.log_level, tag=reduced_log_tag, directory=args.log_dir, rank=rank)

if args.distributed:
torch.cuda.set_device(local_rank)
@@ -1118,13 +1121,16 @@ def plot_species_neighbors(raw, element_symbols):

if rank == 0:
# Save entire model
# name length hack
tag = tag[100:]
if swa_eval:
model_path = Path(args.checkpoints_dir) / (tag + "_swa.model")
else:
model_path = Path(args.checkpoints_dir) / (tag + ".model")
logging.info(f"Saving model to {model_path}")
if args.save_cpu:
model = model.to("cpu")

torch.save(model, model_path)
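An aside on the "name length hack" above: tag = tag[100:] drops the first 100 characters and keeps the tail, so a unique trailing suffix survives when a tag exceeds typical filename limits (around 255 bytes on most filesystems); tag[:100] would keep the prefix instead. A quick illustration with a hypothetical tag:

    tag = "x" * 140 + "_run7"   # 145 characters, too long once suffixes are appended
    tag[100:]                   # 45 chars: keeps the trailing "..._run7" suffix
    tag[:100]                   # would keep the 100-char descriptive prefix instead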

#if swa_eval:
27 changes: 25 additions & 2 deletions mace/cli/preprocess_data.py
@@ -21,6 +21,27 @@
from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz, get_dataset_from_h5, get_dataset_from_extxyzs, get_dataset_from_jsonbz2s
from mace.tools.utils import AtomicNumberTable

import time
import random
from functools import wraps

def retry_with_backoff(retries=5, backoff_in_seconds=1):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
x = 0
while True:
try:
return func(*args, **kwargs)
except BlockingIOError as e:
if x == retries:
raise e
sleep = (backoff_in_seconds * 2 ** x +
random.uniform(0, 1))
time.sleep(sleep)
x += 1
return wrapper
return decorator
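The new retry_with_backoff decorator retries the wrapped call when it raises BlockingIOError, which is how h5py commonly surfaces HDF5 file-lock contention; after failed attempt x it sleeps backoff_in_seconds * 2**x seconds plus up to one second of jitter, and re-raises once `retries` attempts have failed. A minimal usage sketch (the function name and path are hypothetical):

    import h5py

    @retry_with_backoff(retries=3, backoff_in_seconds=1)
    def locked_write(path):
        # On a locked file, retries sleep roughly 1s, 2s, 4s (plus jitter) apart.
        with h5py.File(path, "w") as f:
            f.attrs["drop_last"] = False

    locked_write("train/train_0.h5")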

def compute_stats_target(
file: str,
@@ -94,11 +115,13 @@ def get_prime_factors(n: int):
return factors

# Define Task for Multiprocessing
@retry_with_backoff()
def multi_train_hdf5(process, h5_prefix, drop_last, split_train):
with h5py.File(h5_prefix + "train/train_" + str(process)+".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_train, process, f)

@retry_with_backoff()
def multi_valid_hdf5(process, h5_prefix, drop_last, split_valid):
with h5py.File(h5_prefix + "val/val_" + str(process)+".h5", "w") as f:
f.attrs["drop_last"] = drop_last
@@ -218,7 +241,7 @@ def main():

processes = []
for i in range(args.num_process):
p = mp.Process(target=multi_train_hdf5, args=[i, args.h5_prefix, drop_last, split_train[i]])
p = mp.Process(target=multi_train_hdf5, args=[args.idx if args.idx is not None else i, args.h5_prefix, drop_last, split_train[i]])
p.start()
processes.append(p)
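This change (mirrored in the validation loop below) lets a caller pin the shard number: when args.idx is set, every spawned worker writes to shard args.idx rather than its loop index, which presumably assumes a single process per job, e.g. one scheduler array task per shard. A hypothetical invocation under that assumption, if idx is exposed as a CLI flag:

    # writes <h5_prefix>train/train_7.h5
    python -m mace.cli.preprocess_data --num_process 1 --idx 7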

@@ -267,7 +290,7 @@ def main():

processes = []
for i in range(args.num_process):
p = mp.Process(target=multi_valid_hdf5, args=[i, args.h5_prefix, drop_last, split_valid[i]])
p = mp.Process(target=multi_valid_hdf5, args=[args.idx if args.idx is not None else i, args.h5_prefix, drop_last, split_valid[i]])
p.start()
processes.append(p)

4 changes: 2 additions & 2 deletions mace/cli/run_train.py
@@ -302,7 +302,7 @@ def main() -> None:
for head, head_args in args.heads.items():
logging.info(f"============= Reading dataset {head} and compute ===========")

if head_args.transform == "stress_kbar2evA":
if head_args.get("transform", None) and head_args.transform == "stress_kbar2evA":
def stress_kbar2evA(atomic_data):
atomic_data["stress"] = atomic_data["stress"] * -1e-1 * ase.units.GPa
return atomic_data
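For reference, the conversion works out as follows: 1 kbar = 0.1 GPa, and ase.units.GPa expresses one gigapascal in ASE's native eV/Å³ (about 1/160.2177), so multiplying by -1e-1 * ase.units.GPa takes kbar to eV/Å³ while flipping the sign, presumably to match the sign convention of the source data. A worked check:

    import ase.units

    s_kbar = 10.0                                # 10 kbar = 1 GPa
    s_evA3 = s_kbar * -1e-1 * ase.units.GPa      # ~= -0.0062415 eV/A^3
    assert abs(s_evA3 + 1 / 160.2176634) < 1e-9  # 1 GPa in eV/A^3, sign flipped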
@@ -477,7 +477,7 @@ def stress_kbar2evA(atomic_data):
train_set = ConcatDataset(train_sets.values())

# mask dataset
if True:
if False:

# Now apply the filter function to your train_set
# masked_indices = filter_data(train_set)
7 changes: 7 additions & 0 deletions mace/data/atomic_data.py
@@ -42,6 +42,7 @@ class AtomicData(torch_geometric.data.Data):
forces_weight: torch.Tensor
stress_weight: torch.Tensor
virials_weight: torch.Tensor
alex_config_id: str

def __init__(
self,
@@ -63,6 +64,7 @@ def __init__(
virials: Optional[torch.Tensor], # [1,3,3]
dipole: Optional[torch.Tensor], # [, 3]
charges: Optional[torch.Tensor], # [n_nodes, ]
alex_config_id: Optional[str],
):
# Check shapes
num_nodes = node_attrs.shape[0]
@@ -106,6 +108,7 @@ def __init__(
"virials": virials,
"dipole": dipole,
"charges": charges,
"alex_config_id": alex_config_id,
}
super().__init__(**data)

@@ -131,6 +134,9 @@ def from_config(
except:
head = torch.tensor(len(heads) - 1, dtype=torch.long)


alex_config_id = config.alex_config_id

cell = (
torch.tensor(config.cell, dtype=torch.get_default_dtype())
if config.cell is not None
@@ -223,6 +229,7 @@ def from_config(
virials=virials,
dipole=dipole,
charges=charges,
alex_config_id=alex_config_id,
)


4 changes: 2 additions & 2 deletions mace/data/hdf5_dataset.py
Original file line number Diff line number Diff line change
@@ -63,15 +63,15 @@ def __getitem__(self, index):
dipole=unpack_value(subgrp["dipole"][()]),
charges=unpack_value(subgrp["charges"][()]),
weight=unpack_value(subgrp["weight"][()]),
head=unpack_value(subgrp["head"][()]) if hasattr(subgrp, "head") else None,
head=None, # do not assign head according to h5
energy_weight=unpack_value(subgrp["energy_weight"][()]),
forces_weight=unpack_value(subgrp["forces_weight"][()]),
stress_weight=unpack_value(subgrp["stress_weight"][()]),
virials_weight=unpack_value(subgrp["virials_weight"][()]),
config_type=unpack_value(subgrp["config_type"][()]),
pbc=unpack_value(subgrp["pbc"][()]),
cell=unpack_value(subgrp["cell"][()]),
alex_config_id=unpack_value(subgrp["alex_config_id"][()]) if hasattr(subgrp, "alex_config_id") else None,
alex_config_id=unpack_value(subgrp["alex_config_id"][()]) if "alex_config_id" in subgrp else None,
)
if config.head is None:
config.head = self.kwargs.get("head")
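Both changed lines address the same pitfall: an h5py Group is dict-like, so membership is tested with `in`; hasattr checks Python attributes on the Group object and is False for stored members, which is why the old head guard never fired. A small sketch against a hypothetical file layout:

    import h5py

    with h5py.File("example.h5", "r") as f:          # hypothetical layout
        subgrp = f["config_batch_0/config_0"]
        present = "alex_config_id" in subgrp         # True iff the member exists
        broken = hasattr(subgrp, "alex_config_id")   # always False for HDF5 members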
84 changes: 70 additions & 14 deletions mace/data/utils.py
@@ -35,6 +35,7 @@

DEFAULT_CONFIG_TYPE = "Default"
DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0}
N_PROC=1


@dataclass
@@ -496,12 +497,14 @@ def read_atoms_file(identifier):

def read_atoms_jsonbz2_go(identifier):
trajs = []
trajs_ids = []
with bz2.open(identifier, 'rt') as f:
data = json.load(f)

for entry in data.keys():
for traj_idx, system in enumerate(data[entry]):
traj = []
config_ids = []
for image_idx, image in enumerate(system['steps']):
positions = np.array(
[site['xyz'] for site in image['structure']['sites']]
@@ -536,14 +539,51 @@ def read_atoms_jsonbz2_go(identifier):
atoms.info['energy'] = energy
atoms.arrays['forces'] = forces
atoms.info['stress'] = stress
config_id = os.path.basename(identifier).split('.')[0] + f"{entry}-{traj_idx}-{image_idx}"
config_id = os.path.basename(identifier).split('.')[0] + f"-{entry}-{traj_idx}-{image_idx}"
atoms.info['alex_config_id'] = config_id
traj.append(atoms)
config_ids.append(config_id)
# put into trajs
trajs.append(traj)
trajs_ids.append(config_ids)

trajs = alex_traj_removing(trajs)
trajs = alex_traj_subsampling(trajs)
total_trajs = len(trajs)
total_images = sum(len(traj) for traj in trajs)

trajs, removed_trajs_ids = alex_traj_removing(trajs, trajs_ids)
trajs, subsampled_trajs_ids = alex_traj_subsampling(trajs, removed_trajs_ids)

# logging:
selected_trajs = len([traj for traj in trajs if traj])
selected_images = sum(len(traj) for traj in trajs)


# Prepare the log content with percentages
traj_percentage = (selected_trajs / total_trajs) * 100 if total_trajs > 0 else 0
image_percentage = (selected_images / total_images) * 100 if total_images > 0 else 0

log_content = [
f"Trajectories: \t{selected_trajs} \t/ \t{total_trajs} \t[{traj_percentage:.2f}%]",
f"Images: \t{selected_images} \t/ \t{total_images} \t[{image_percentage:.2f}%]\n"
]

# Process each trajectory for detailed logging
for orig_traj_ids, final_traj_ids in zip(trajs_ids, subsampled_trajs_ids):
if not final_traj_ids: # This trajectory was removed
entry = orig_traj_ids[0].split('-')[1] # Extract entry from the first config_id
traj_idx = orig_traj_ids[0].split('-')[2] # ... and the traj index, matching the else-branch indices
log_content.append(f"{entry}-{traj_idx}: removed")
else:
entry = final_traj_ids[0].split('-')[1]
traj_idx = final_traj_ids[0].split('-')[2]
kept_images = len(final_traj_ids)
total_images = len(orig_traj_ids)
kept_indices = [int(config_id.split('-')[-1]) for config_id in final_traj_ids]
log_content.append(f"{entry}-{traj_idx}: {kept_images}/{total_images} [{', '.join(map(str, kept_indices))}]")

# Write the log file
log_filename = f"{os.path.basename(identifier).split('.')[0]}.processing.log"
with open(log_filename, 'w') as log_file:
log_file.write('\n'.join(log_content))

return [atom for sublist in trajs for atom in sublist]
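Given the format strings above, each per-file processing log looks roughly like this (a hypothetical excerpt; entry ids and indices are made up):

    Trajectories: 	3 	/ 	5 	[60.00%]
    Images: 	12 	/ 	74 	[16.22%]

    agm12345-0: 4/20 [0, 7, 14, 19]
    agm12345-1: removed
    agm67890-0: 8/34 [0, 5, 11, 17, 22, 26, 30, 33]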

@@ -600,9 +640,10 @@ def max_stress_forces_energy_per_atom(traj):
max_energy_per_atom_value = energy_per_atom_value
return max_stress_value, max_forces_value, max_energy_per_atom_value

def alex_traj_removing(trajs):
def alex_traj_removing(trajs, trajs_ids):
filtered_trajs = []
for traj in trajs:
filtered_trajs_ids = []
for traj, config_ids in zip(trajs, trajs_ids):
stress_value, forces_value, energy_per_atom_value = max_stress_forces_energy_per_atom(traj)
final_forces_norm = np.linalg.norm(traj[-1].arrays['forces'], axis=-1).max()

@@ -611,39 +652,54 @@ def alex_traj_removing(trajs):
forces_value <= 300 and
forces_value > 0.0 and
energy_per_atom_value <= 2.0 and
len(traj) >= 4 and
final_forces_norm <= 0.005):
filtered_trajs.append(traj)
filtered_trajs_ids.append(config_ids)
else:
filtered_trajs.append([])
filtered_trajs_ids.append([])

return filtered_trajs
return filtered_trajs, filtered_trajs_ids

def alex_traj_subsampling(trajs):
def alex_traj_subsampling(trajs, trajs_ids):
subsampled_trajs = []
subsampled_trajs_ids = []

for traj in trajs:
alex_e0s = {1: -1.11734008, 2: 0.00096759, 3: -0.29754725, 4: -0.01781697, 5: -0.26885011, 6: -1.26173507, 7: -3.12438806, 8: -1.54838784, 9: -0.51882044, 10: -0.01241601, 11: -0.22883163, 12: -0.00951015, 13: -0.21630193, 14: -0.8263903, 15: -1.88816619, 16: -0.89160769, 17: -0.25828273, 18: -0.04925973, 19: -0.22697913, 20: -0.0927795, 21: -2.11396364, 22: -2.50054871, 23: -3.70477179, 24: -5.60261985, 25: -5.32541181, 26: -3.52004933, 27: -1.93555024, 28: -0.9351969, 29: -0.60025846, 30: -0.1651332, 31: -0.32990651, 32: -0.77971828, 33: -1.68367812, 34: -0.76941032, 35: -0.22213843, 36: -0.0335879, 37: -0.1881724, 38: -0.06826294, 39: -2.17084228, 40: -2.28579303, 41: -3.13429753, 42: -4.60211419, 43: -3.45201492, 44: -2.38073513, 45: -1.46855515, 46: -1.4773126, 47: -0.33954585, 48: -0.16843877, 49: -0.35470981, 50: -0.83642657, 51: -1.41101987, 52: -0.65740879, 53: -0.18964571, 54: -0.00857582, 55: -0.13771876, 56: -0.03457659, 57: -0.45580806, 58: -1.3309175, 59: -0.29671824, 60: -0.30391193, 61: -0.30898427, 62: -0.25470891, 63: -8.38001538, 64: -10.38896525, 65: -0.3059505, 66: -0.30676216, 67: -0.30874667, 68: -0.31610927, 69: -0.25190039, 70: -0.06431414, 71: -0.31997586, 72: -3.52770927, 73: -3.54492209, 74: -4.65658356, 75: -4.70108713, 76: -2.88257209, 77: -1.46779304, 78: -0.50269936, 79: -0.28801193, 80: -0.12454674, 81: -0.31737194, 82: -0.77644932, 83: -1.32627283, 89: -0.26827152, 90: -0.90817426, 91: -2.47653193, 92: -4.90438537, 93: -7.63378961, 94: -10.77237713}

for traj, config_ids in zip(trajs, trajs_ids):
# Skip empty trajectories
if not traj:
subsampled_trajs.append([])
subsampled_trajs_ids.append([])
continue

zs = traj[0].numbers
E0 = sum(alex_e0s[z] for z in zs if z in alex_e0s)

# Remove first image if trajectory has more than one atom
atom_list = traj[1:] if len(traj) > 1 else traj
config_list = config_ids[1:] if len(traj) > 1 else config_ids

# Extract energies
try:
energies = [atom.info['energy'] for atom in atom_list]
energies = [(atom.info['energy'] - E0) for atom in atom_list]
except KeyError:
print(f"Warning: 'energy' not found in atom.info for a trajectory. Skipping this trajectory.")
subsampled_trajs.append(traj) # Keep the original trajectory
# subsampled_trajs.append(traj) # Keep the original trajectory
continue

# Sample indices
indices = sample_energy_time_series_reverse(energies, relative_threshold=0.001)

# Create subsampled trajectory
subsampled_traj = [atom_list[i] for i in indices]
subsampled_configs = [config_list[i] for i in indices]
subsampled_trajs.append(subsampled_traj)
subsampled_trajs_ids.append(subsampled_configs)

return subsampled_trajs
return subsampled_trajs, subsampled_trajs_ids
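The subsampling change references each image's energy to the sum of isolated-atom energies (alex_e0s, in eV), so the relative threshold acts on something like an atomization energy rather than a large absolute total. A worked example under that reading, using the tabulated values for H (Z=1) and O (Z=8):

    # For an H2O configuration: E0 = 2 * (-1.11734008) + (-1.54838784) = -3.78306800 eV
    zs = [1, 1, 8]
    E0 = sum(alex_e0s[z] for z in zs if z in alex_e0s)
    # An image with total energy -17.2 eV enters the series as -17.2 - E0 ~= -13.417 eV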

def read_atoms_jsonbz2(identifier):
atom_list = []
@@ -697,7 +753,7 @@ def atoms_from_oc20(file_path, positions_key='coordinates', numbers_key="species
filenames = [f for f in os.listdir(file_path) if f.endswith(".extxyz")]
identifiers = [os.path.join(file_path, f) for f in filenames if f.endswith(".extxyz")]

with mp.Pool(16) as pool:
with mp.Pool(N_PROC) as pool:
results = list(tqdm(pool.imap(read_atoms_file, identifiers), total=len(identifiers)))

# Flatten the list of lists
Expand All @@ -710,7 +766,7 @@ def atoms_from_alex_go(file_path, positions_key='coordinates', numbers_key="spec
filenames = [f for f in os.listdir(file_path) if f.endswith(".json.bz2") and f.startswith("alex_go")]
identifiers = [os.path.join(file_path, f) for f in filenames]

with mp.Pool(4) as pool:
with mp.Pool(N_PROC) as pool:
results = list(tqdm(pool.imap(read_atoms_jsonbz2_go, identifiers), total=len(identifiers)))

# Flatten the list of lists
@@ -722,7 +778,7 @@ def atoms_from_alex(file_path, positions_key='coordinates', numbers_key="species
filenames = [f for f in os.listdir(file_path) if f.endswith(".json.bz2") and f.startswith("alexandria")]
identifiers = [os.path.join(file_path, f) for f in filenames]

with mp.Pool(16) as pool:
with mp.Pool(N_PROC) as pool:
results = list(tqdm(pool.imap(read_atoms_jsonbz2, identifiers), total=len(identifiers)))

# Flatten the list of lists
4 changes: 4 additions & 0 deletions mace/modules/__init__.py
@@ -19,6 +19,8 @@
RealAgnosticResidualInteractionBlock,
RealAgnosticInteractionGateBlock,
RealAgnosticResidualInteractionGateBlock,
RealAgnosticSimplifiedDensityInteractionBlock,
RealAgnosticSimplifiedDensityResidualInteractionBlock,
RealAgnosticDensityNormalizedInteractionGateBlock,
RealAgnosticDensityNormalizedNoScaleInteractionGateBlock,
RealAgnosticDensityInjuctedInteractionGateBlock,
@@ -85,6 +87,8 @@
"RealAgnosticDensityInjuctedNoScaleNoBiasResidualInteractionBlock": RealAgnosticDensityInjuctedNoScaleNoBiasResidualInteractionGateBlock,
"RealAgnosticDensityInjuctedNoScaleResidualInteractionBlock": RealAgnosticDensityInjuctedNoScaleResidualInteractionGateBlock,
"RealAgnosticDensityInjuctUnnormalizedNoScaleInteractionBlock": RealAgnosticDensityInjuctUnnormalizedNoScaleInteractionGateBlock,
"RASimpleDensityIntBlock": RealAgnosticSimplifiedDensityInteractionBlock,
"RASimpleDensityResidualIntBlock": RealAgnosticSimplifiedDensityResidualInteractionBlock,
}
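The two new entries follow the module's existing registry pattern: interaction blocks are registered under string keys so a model config can name a class without importing it. A lookup sketch (constructor arguments omitted, since they come from the model builder):

    interaction_cls = interaction_classes["RASimpleDensityIntBlock"]
    # interaction_cls is RealAgnosticSimplifiedDensityInteractionBlock; the model
    # builder instantiates it with the usual irreps/radial arguments.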

scaling_classes: Dict[str, Callable] = {
(Diffs for the remaining changed files are not shown.)
