Skip to content

Commit

Permalink
add alex data sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
birdyLinch committed Oct 7, 2024
1 parent ffe3b89 commit 988afa7
Show file tree
Hide file tree
Showing 39 changed files with 841 additions and 1,452 deletions.
28 changes: 19 additions & 9 deletions mace/cli/plot_neighbor_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,11 +1081,21 @@ def plot_species_neighbors(raw, element_symbols):
# all_data_loaders[test_name] = test_loader

for swa_eval in swas:
epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=swa_eval,
device=device,
)
if args.ckpt_path is None:
epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=swa_eval,
device=device,
)
else:
ckpt_name = checkpoint_handler.io._get_checkpoint_filename(int(args.ckpt_path), None)
epoch = checkpoint_handler.load(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
path=Path(args.checkpoints_dir) / ckpt_name,
device=device,
)
tag = tag + f"_epoch-{epoch}"

model.to(device)
if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank])
Expand Down Expand Up @@ -1117,10 +1127,10 @@ def plot_species_neighbors(raw, element_symbols):
model = model.to("cpu")
torch.save(model, model_path)

if swa_eval:
torch.save(model, Path(args.model_dir) / (args.name + "_swa.model"))
else:
torch.save(model, Path(args.model_dir) / (args.name + ".model"))
#if swa_eval:
# torch.save(model, Path(args.model_dir) / (args.name + "_swa.model"))
#else:
# torch.save(model, Path(args.model_dir) / (args.name + ".model"))

if args.distributed:
torch.distributed.barrier()
Expand Down
20 changes: 18 additions & 2 deletions mace/cli/preprocess_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mace.data.utils import save_configurations_as_HDF5
from mace.modules import compute_statistics
from mace.tools import torch_geometric
from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz, get_dataset_from_h5, get_dataset_from_extxyzs
from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz, get_dataset_from_h5, get_dataset_from_extxyzs, get_dataset_from_jsonbz2s
from mace.tools.utils import AtomicNumberTable


Expand Down Expand Up @@ -173,6 +173,16 @@ def main():
test_path=args.test_file,
seed=args.seed,
)
elif "alexandria" in args.train_file and os.path.isdir(args.train_file):
collections, atomic_energies_dict, _ = get_dataset_from_jsonbz2s(
train_path=args.train_file,
valid_path=args.valid_file,
valid_fraction=args.valid_fraction,
config_type_weights=config_type_weights,
test_path=args.test_file,
seed=args.seed,
)



# Atomic number table
Expand All @@ -190,6 +200,12 @@ def main():
assert isinstance(zs_list, list)
z_table = tools.get_atomic_number_table_from_zs(zs_list)

try:
import torch
torch.save(z_table, 'z_table.pt')
except:
print(z_table)

logging.info("Preparing training set")
if args.shuffle:
random.shuffle(collections.train)
Expand Down Expand Up @@ -283,5 +299,5 @@ def multi_test_hdf5(process, name):

if __name__ == "__main__":
mp.set_start_method('spawn')
os.chdir("/mnt/petrelfs/linchen/FoundationalModel/mace_multi_head_interface")
#os.chdir("/mnt/petrelfs/linchen/FoundationalModel/mace_multi_head_interface")
main()
78 changes: 76 additions & 2 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,59 @@
from mace.tools.utils import AtomicNumberTable
from torch.utils.data import ConcatDataset
from box import Box
from tqdm import tqdm


def bad_force(data):
forces = data["forces"]
forces_norm_max = forces.norm(dim=-1).max().item()
if forces_norm_max > 300.0:
return True
else:
return False

def bad_iso(data):
forces = data["forces"]
n_atoms = forces.size(0)

if n_atoms > 1:
return False
else:
return True

def bad_energy(data):
energy = data["energy"].item()
forces = data["forces"]
n_atoms = forces.size(0)
e_per_atom = energy / n_atoms

if -20.0 > e_per_atom or e_per_atom > 2.0:
return True
else:
return False

def bad_stress(data):
stress = data['stress']
stress_norm = stress.abs().max().item()
if stress_norm > 1.0:
return True
else:
return False

import multiprocessing as mp
from tqdm import tqdm

def is_bad(data):
"""Check if a sample is 'bad' based on force, energy, and stress."""
return bad_force(data) or bad_energy(data) or bad_stress(data) or bad_iso(data)

def filter_data(train_set):
"""Filter out 'bad' samples from the dataset using multiprocessing."""
with mp.Pool(4) as pool:
mask = list(tqdm(pool.imap(is_bad, train_set), total=len(train_set)))

# Return the indices of the good data points (where mask is False)
return [i for i, bad in enumerate(mask) if not bad]

def format_number(num):
if num >= 1_000_000_000:
Expand Down Expand Up @@ -248,6 +301,15 @@ def main() -> None:

for head, head_args in args.heads.items():
logging.info(f"============= Reading dataset {head} and compute ===========")

if head_args.transform == "stress_kbar2evA":
def stress_kbar2evA(atomic_data):
atomic_data["stress"] = atomic_data["stress"] * -1e-1 * ase.units.GPa
return atomic_data
head_transform = stress_kbar2evA
else:
head_transform = None

if head_args.train_file.endswith(".xyz"):
# TODO: test this branch
if head_args.valid_file is not None:
Expand Down Expand Up @@ -289,10 +351,10 @@ def main() -> None:
head_args.valid_set = data.HDF5Dataset(head_args.valid_file, r_max=head_args.r_max, z_table=z_table, head=head, heads=list(args.heads.keys()))
else: # This case would be for when the file path is to a directory of multiple .h5 files
head_args.train_set = data.dataset_from_sharded_hdf5(
head_args.train_file, r_max=head_args.r_max, z_table=z_table, head=head, heads=list(args.heads.keys()), rank=rank
head_args.train_file, r_max=head_args.r_max, z_table=z_table, head=head, heads=list(args.heads.keys()), rank=rank, transform=head_transform
)
head_args.valid_set = data.dataset_from_sharded_hdf5(
head_args.valid_file, r_max=head_args.r_max, z_table=z_table, head=head, heads=list(args.heads.keys()), rank=rank
head_args.valid_file, r_max=head_args.r_max, z_table=z_table, head=head, heads=list(args.heads.keys()), rank=rank, transform=head_transform
)

logging.info(f"Dataset {head} size --> {format_number(len(head_args.train_set))}")
Expand All @@ -304,10 +366,14 @@ def main() -> None:
subset_size = int(ratio * len(head_args.train_set))
remaining_size = len(head_args.train_set) - subset_size

val_subset_size = int(ratio * len(head_args.valid_set))
val_remaining_size = len(head_args.valid_set) - val_subset_size
# Split the dataset
head_args.train_set, _ = random_split(head_args.train_set, [subset_size, remaining_size])
head_args.valid_set, _ = random_split(head_args.valid_set, [val_subset_size, val_remaining_size])

logging.info(f"Dataset {head} subsampled size --> {format_number(len(head_args.train_set))}")
logging.info(f"Dataset {head} subsampled valid size --> {format_number(len(head_args.valid_set))}")


# head specific train_sampler
Expand Down Expand Up @@ -410,6 +476,14 @@ def main() -> None:

train_set = ConcatDataset(train_sets.values())

# mask dataset
if True:

# Now apply the filter function to your train_set
# masked_indices = filter_data(train_set)

masked_indices = [i for i in tqdm(range(len(train_set))) if not is_bad(train_set[i])]
train_set = torch.utils.data.Subset(train_set, masked_indices)

if args.model == "AtomicDipolesMACE":
atomic_energies = None
Expand Down
4 changes: 4 additions & 0 deletions mace/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
config_from_atoms,
config_from_atoms_list,
load_from_xyz,
load_from_jsonbz2s,
load_from_jsonbz2s_go,
load_from_h5,
load_from_extxyzs,
random_train_valid_split,
Expand All @@ -33,4 +35,6 @@
"dataset_from_sharded_hdf5",
"save_AtomicData_to_HDF5",
"save_configurations_as_HDF5",
"load_from_jsonbz2s",
"load_from_jsonbz2s_go",
]
8 changes: 6 additions & 2 deletions mace/data/hdf5_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, file_path, r_max, z_table, **kwargs):
except KeyError:
self.drop_last = False
self.kwargs = kwargs
self.transform = kwargs['transform'] if 'transform' in kwargs else None

@property
def file(self):
Expand Down Expand Up @@ -57,7 +58,7 @@ def __getitem__(self, index):
positions=subgrp["positions"][()],
energy=unpack_value(subgrp["energy"][()]),
forces=unpack_value(subgrp["forces"][()]),
stress=unpack_value(subgrp["stress"][()]),
stress=unpack_value(subgrp["stress"][()]),
virials=unpack_value(subgrp["virials"][()]),
dipole=unpack_value(subgrp["dipole"][()]),
charges=unpack_value(subgrp["charges"][()]),
Expand All @@ -70,6 +71,7 @@ def __getitem__(self, index):
config_type=unpack_value(subgrp["config_type"][()]),
pbc=unpack_value(subgrp["pbc"][()]),
cell=unpack_value(subgrp["cell"][()]),
alex_config_id=unpack_value(subgrp["alex_config_id"][()]) if hasattr(subgrp, "alex_config_id") else None,
)
if config.head is None:
config.head = self.kwargs.get("head")
Expand All @@ -82,13 +84,15 @@ def __getitem__(self, index):
)
except:
import ipdb; ipdb.set_trace()
if self.transform:
atomic_data = self.transform(atomic_data)
return atomic_data


def dataset_from_sharded_hdf5(
files: List, z_table: AtomicNumberTable, r_max: float, **kwargs
):
files = glob(files + "/*")
files = glob(files + "/*.h5")
datasets = []

if 'rank' not in kwargs or ('rank' in kwargs and kwargs['rank'] == 0):
Expand Down
Loading

0 comments on commit 988afa7

Please sign in to comment.