Skip to content

Commit

Permalink
fixed mixed pbc non pbc training
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Sep 4, 2024
1 parent 58f1f3a commit 6d768c2
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 39 deletions.
30 changes: 19 additions & 11 deletions mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,25 @@ def select_samples(
"Filtering configurations based on the finetuning set, "
f"filtering type: combinations, elements: {all_species_ft}"
)
if args.descriptors is not None:
logging.info("Loading descriptors")
descriptors = np.load(args.descriptors, allow_pickle=True)
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
for i, atoms in enumerate(atoms_list_pt):
atoms.info["mace_descriptors"] = descriptors[i]
atoms_list_pt_filtered = [
x
for x in atoms_list_pt
if filter_atoms(x, all_species_ft, "combinations")
]
if args.subselect != "random":
if args.descriptors is not None:
logging.info("Loading descriptors")
descriptors = np.load(args.descriptors, allow_pickle=True)
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
for i, atoms in enumerate(atoms_list_pt):
atoms.info["mace_descriptors"] = descriptors[i]
atoms_list_pt_filtered = [
x
for x in atoms_list_pt
if filter_atoms(x, all_species_ft, "combinations")
]
else:
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
atoms_list_pt_filtered = [
x
for x in atoms_list_pt
if filter_atoms(x, all_species_ft, "combinations")
]
else:
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
atoms_list_pt_filtered = [
Expand Down
6 changes: 3 additions & 3 deletions mace/data/atomic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def from_config(
) -> "AtomicData":
if heads is None:
heads = ["default"]
edge_index, shifts, unit_shifts = get_neighborhood(
edge_index, shifts, unit_shifts, cell = get_neighborhood(
positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell
)
indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table)
Expand All @@ -133,8 +133,8 @@ def from_config(
head = torch.tensor(len(heads) - 1, dtype=torch.long)

cell = (
torch.tensor(config.cell, dtype=torch.get_default_dtype())
if config.cell is not None
torch.tensor(cell, dtype=torch.get_default_dtype())
if cell is not None
else torch.tensor(
3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype()
).view(3, 3)
Expand Down
12 changes: 6 additions & 6 deletions mace/data/neighborhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@ def get_neighborhood(
max_positions = np.max(np.absolute(positions)) + 1
# Extend cell in non-periodic directions
# For models with more than 5 layers, the multiplicative constant needs to be increased.
temp_cell = np.copy(cell)
# temp_cell = np.copy(cell)
if not pbc_x:
temp_cell[0, :] = max_positions * 5 * cutoff * identity[0, :]
cell[0, :] = max_positions * 5 * cutoff * identity[0, :]
if not pbc_y:
temp_cell[1, :] = max_positions * 5 * cutoff * identity[1, :]
cell[1, :] = max_positions * 5 * cutoff * identity[1, :]
if not pbc_z:
temp_cell[2, :] = max_positions * 5 * cutoff * identity[2, :]
cell[2, :] = max_positions * 5 * cutoff * identity[2, :]

sender, receiver, unit_shifts = neighbour_list(
quantities="ijS",
pbc=pbc,
cell=temp_cell,
cell=cell,
positions=positions,
cutoff=cutoff,
# self_interaction=True, # we want edges from atom to itself in different periodic images
Expand All @@ -63,4 +63,4 @@ def get_neighborhood(
# D = positions[j]-positions[i]+S.dot(cell)
shifts = np.dot(unit_shifts, cell) # [n_edges, 3]

return edge_index, shifts, unit_shifts
return edge_index, shifts, unit_shifts, cell
42 changes: 27 additions & 15 deletions mace/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,34 +114,34 @@ def conditional_mse_forces(ref: Batch, pred: TensorDict) -> torch.Tensor:


def conditional_huber_forces(
ref: Batch, pred: TensorDict, huber_delta: float
ref_forces: Batch, pred_forces: TensorDict, huber_delta: float
) -> torch.Tensor:
# Define the multiplication factors for each condition
factors = huber_delta * torch.tensor([1.0, 0.7, 0.4, 0.1])

# Apply multiplication factors based on conditions
c1 = torch.norm(ref["forces"], dim=-1) < 100
c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & (
torch.norm(ref["forces"], dim=-1) < 200
c1 = torch.norm(ref_forces, dim=-1) < 100
c2 = (torch.norm(ref_forces, dim=-1) >= 100) & (
torch.norm(ref_forces, dim=-1) < 200
)
c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & (
torch.norm(ref["forces"], dim=-1) < 300
c3 = (torch.norm(ref_forces, dim=-1) >= 200) & (
torch.norm(ref_forces, dim=-1) < 300
)
c4 = ~(c1 | c2 | c3)

se = torch.zeros_like(pred["forces"])
se = torch.zeros_like(pred_forces)

se[c1] = torch.nn.functional.huber_loss(
ref["forces"][c1], pred["forces"][c1], reduction="none", delta=factors[0]
ref_forces[c1], pred_forces[c1], reduction="none", delta=factors[0]
)
se[c2] = torch.nn.functional.huber_loss(
ref["forces"][c2], pred["forces"][c2], reduction="none", delta=factors[1]
ref_forces[c2], pred_forces[c2], reduction="none", delta=factors[1]
)
se[c3] = torch.nn.functional.huber_loss(
ref["forces"][c3], pred["forces"][c3], reduction="none", delta=factors[2]
ref_forces[c3], pred_forces[c3], reduction="none", delta=factors[2]
)
se[c4] = torch.nn.functional.huber_loss(
ref["forces"][c4], pred["forces"][c4], reduction="none", delta=factors[3]
ref_forces[c4], pred_forces[c4], reduction="none", delta=factors[3]
)

return torch.mean(se)
Expand Down Expand Up @@ -273,15 +273,27 @@ def __init__(

def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor:
num_atoms = ref.ptr[1:] - ref.ptr[:-1]
configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ]
configs_energy_weight = ref.energy_weight # [n_graphs, ]
configs_forces_weight = torch.repeat_interleave(
ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1]
).unsqueeze(-1)
return (
self.energy_weight
* self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms)
* self.huber_loss(
configs_energy_weight * ref["energy"] / num_atoms,
configs_energy_weight * pred["energy"] / num_atoms,
)
+ self.forces_weight
* conditional_huber_forces(ref, pred, huber_delta=self.huber_delta)
* conditional_huber_forces(
configs_forces_weight * ref["forces"],
configs_forces_weight * pred["forces"],
huber_delta=self.huber_delta,
)
+ self.stress_weight
* self.huber_loss(
ref["stress"],
pred["stress"],
configs_stress_weight * ref["stress"],
configs_stress_weight * pred["stress"],
)
)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_basic(self):
]
)

indices, shifts, unit_shifts = get_neighborhood(positions, cutoff=1.5)
indices, shifts, unit_shifts, _ = get_neighborhood(positions, cutoff=1.5)
assert indices.shape == (2, 4)
assert shifts.shape == (4, 3)
assert unit_shifts.shape == (4, 3)
Expand All @@ -158,7 +158,7 @@ def test_signs(self):
)

cell = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
edge_index, shifts, unit_shifts = get_neighborhood(
edge_index, shifts, unit_shifts, _ = get_neighborhood(
positions, cutoff=3.5, pbc=(True, False, False), cell=cell
)
num_edges = 10
Expand All @@ -172,7 +172,7 @@ def test_periodic_edge():
atoms = ase.build.bulk("Cu", "fcc")
dist = np.linalg.norm(atoms.cell[0]).item()
config = config_from_atoms(atoms)
edge_index, shifts, _ = get_neighborhood(
edge_index, shifts, _, _ = get_neighborhood(
config.positions, cutoff=1.05 * dist, pbc=(True, True, True), cell=config.cell
)
sender, receiver = edge_index
Expand All @@ -190,7 +190,7 @@ def test_half_periodic():
atoms = ase.build.fcc111("Al", size=(3, 3, 1), vacuum=0.0)
assert all(atoms.pbc == (True, True, False))
config = config_from_atoms(atoms) # first shell dist is 2.864A
edge_index, shifts, _ = get_neighborhood(
edge_index, shifts, _, _ = get_neighborhood(
config.positions, cutoff=2.9, pbc=(True, True, False), cell=config.cell
)
sender, receiver = edge_index
Expand Down

0 comments on commit 6d768c2

Please sign in to comment.