Skip to content

Commit

Permalink
allow missing labels in training
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Oct 28, 2024
1 parent 817e21b commit 0da2d15
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 17 deletions.
58 changes: 44 additions & 14 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
force_loss_ratio: float = 1,
stress_loss_ratio: float = 0.1,
mag_loss_ratio: float = 0.1,
allow_missing_labels: bool = True,
optimizer: str = "Adam",
scheduler: str = "CosLR",
criterion: str = "MSE",
Expand Down Expand Up @@ -78,6 +79,9 @@ def __init__(
Default = 0.1
mag_loss_ratio (float): magmom loss ratio in loss function
Default = 0.1
allow_missing_labels (bool): whether to allow missing labels in the dataset.
Missing (NaN) targets do not contribute to the loss or the MAEs.
Default = True
optimizer (str): optimizer to update model. Can be "Adam", "SGD", "AdamW",
"RAdam". Default = 'Adam'
scheduler (str): learning rate scheduler. Can be "CosLR", "ExponentialLR",
Expand Down Expand Up @@ -209,6 +213,7 @@ def __init__(
force_loss_ratio=force_loss_ratio,
stress_loss_ratio=stress_loss_ratio,
mag_loss_ratio=mag_loss_ratio,
allow_missing_labels=allow_missing_labels,
**kwargs,
)
self.epochs = epochs
Expand Down Expand Up @@ -726,6 +731,7 @@ def __init__(
stress_loss_ratio: float = 0.1,
mag_loss_ratio: float = 0.1,
delta: float = 0.1,
allow_missing_labels: bool = True,
) -> None:
"""Initialize the combined loss.
Expand All @@ -745,6 +751,8 @@ def __init__(
mag_loss_ratio (float): magmom loss ratio in loss function
Default = 0.1
delta (float): delta for torch.nn.HuberLoss. Default = 0.1
allow_missing_labels (bool): whether to allow missing labels in the dataset.
Missing (NaN) targets do not contribute to the loss or the MAEs.
Default = True
"""
super().__init__()
# Define loss criterion
Expand All @@ -771,6 +779,7 @@ def __init__(
self.mag_loss_ratio = 0
else:
self.mag_loss_ratio = mag_loss_ratio
self.allow_missing_labels = allow_missing_labels

def forward(
self,
Expand All @@ -791,25 +800,37 @@ def forward(
out = {"loss": 0.0}
# Energy
if "e" in self.target_str:
if self.is_intensive:
out["loss"] += self.energy_loss_ratio * self.criterion(
targets["e"], prediction["e"]
)
out["e_MAE"] = mae(targets["e"], prediction["e"])
out["e_MAE_size"] = prediction["e"].shape[0]
if self.allow_missing_labels:
valid_value_indices = ~torch.isnan(targets["e"])
valid_e_target = targets["e"][valid_value_indices]
valid_atoms_per_graph = prediction["atoms_per_graph"][
valid_value_indices
]
valid_e_pred = prediction["e"][valid_value_indices]
if valid_e_pred.shape == torch.Size([]):
valid_e_pred = valid_e_pred.view(1)
else:
e_per_atom_target = targets["e"] / prediction["atoms_per_graph"]
e_per_atom_pred = prediction["e"] / prediction["atoms_per_graph"]
out["loss"] += self.energy_loss_ratio * self.criterion(
e_per_atom_target, e_per_atom_pred
)
out["e_MAE"] = mae(e_per_atom_target, e_per_atom_pred)
out["e_MAE_size"] = prediction["e"].shape[0]
valid_e_target = targets["e"]
valid_atoms_per_graph = prediction["atoms_per_graph"]
valid_e_pred = prediction["e"]
if self.is_intensive:
valid_e_target = valid_e_target / valid_atoms_per_graph
valid_e_pred = valid_e_pred / valid_atoms_per_graph

out["loss"] += self.energy_loss_ratio * self.criterion(
valid_e_target, valid_e_pred
)
out["e_MAE"] = mae(valid_e_target, valid_e_pred)
out["e_MAE_size"] = prediction["e"].shape[0]

# Force
if "f" in self.target_str:
forces_pred = torch.cat(prediction["f"], dim=0)
forces_target = torch.cat(targets["f"], dim=0)
if self.allow_missing_labels:
valid_value_indices = ~torch.isnan(forces_target)
forces_target = forces_target[valid_value_indices]
forces_pred = forces_pred[valid_value_indices]
out["loss"] += self.force_loss_ratio * self.criterion(
forces_target, forces_pred
)
Expand All @@ -820,6 +841,10 @@ def forward(
if "s" in self.target_str:
stress_pred = torch.cat(prediction["s"], dim=0)
stress_target = torch.cat(targets["s"], dim=0)
if self.allow_missing_labels:
valid_value_indices = ~torch.isnan(stress_target)
stress_target = stress_target[valid_value_indices]
stress_pred = stress_pred[valid_value_indices]
out["loss"] += self.stress_loss_ratio * self.criterion(
stress_target, stress_pred
)
Expand All @@ -832,7 +857,12 @@ def forward(
m_mae_size = 0
for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True):
# exclude structures without magmom labels
if mag_target is not None:
if self.allow_missing_labels:
if mag_target is not None and not np.isnan(mag_target).any():
mag_preds.append(mag_pred)
mag_targets.append(mag_target)
m_mae_size += mag_target.shape[0]
else:
mag_preds.append(mag_pred)
mag_targets.append(mag_target)
m_mae_size += mag_target.shape[0]
Expand Down
17 changes: 14 additions & 3 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
coords = [[0, 0, 0], [0.5, 0.5, 0.5]]
NaCl = Structure(lattice, species, coords)
structures, energies, forces, stresses, magmoms = [], [], [], [], []
for _ in range(100):
for _ in range(20):
struct = NaCl.copy()
struct.perturb(0.1)
structures.append(struct)
Expand All @@ -30,15 +30,22 @@
stresses.append(np.random.random([3, 3]))
magmoms.append(np.random.random(2))

# Create some missing labels
energies[10] = np.nan
forces[4] = (np.nan * np.ones((len(structures[4]), 3))).tolist()
stresses[6] = (np.nan * np.ones((3, 3))).tolist()
magmoms[8] = (np.nan * np.ones((len(structures[8]), 1))).tolist()

data = StructureData(
structures=structures,
energies=energies,
forces=forces,
stresses=stresses,
magmoms=magmoms,
shuffle=False,
)
train_loader, val_loader, _test_loader = get_train_val_test_loader(
data, batch_size=16, train_ratio=0.9, val_ratio=0.05
data, batch_size=4, train_ratio=0.9, val_ratio=0.05
)
chgnet = CHGNet.load()

Expand All @@ -55,6 +62,7 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
wandb_path="test/run",
wandb_init_kwargs=dict(anonymous="must"),
extra_run_config=extra_run_config,
allow_missing_labels=True,
)
trainer.train(
train_loader,
Expand All @@ -66,7 +74,9 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
for param in chgnet.composition_model.parameters():
assert param.requires_grad is False
assert tmp_path.is_dir(), "Training dir was not created"

for target_str in ["e", "f", "s", "m"]:
assert ~np.isnan(trainer.training_history[target_str]["train"]).any()
assert ~np.isnan(trainer.training_history[target_str]["val"]).any()
output_files = [file.name for file in tmp_path.iterdir()]
for prefix in ("epoch", "bestE_", "bestF_"):
n_matches = sum(file.startswith(prefix) for file in output_files)
Expand Down Expand Up @@ -147,6 +157,7 @@ def test_wandb_init(mock_wandb):
"wandb_path": "test-project/test-run",
"wandb_init_kwargs": {"tags": ["test"]},
"extra_run_config": None,
"allow_missing_labels": True,
}
mock_wandb.init.assert_called_once_with(
project="test-project", name="test-run", config=expected_config, tags=["test"]
Expand Down

0 comments on commit 0da2d15

Please sign in to comment.