Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multitask ASE interface #224

Merged
merged 19 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
c965377
feat: added a base method for extracting multitask results
laserkelvin May 16, 2024
aed05b2
refactor: setting signature for merge output function
laserkelvin May 17, 2024
0f3066d
refactor: returning per-key results
laserkelvin May 17, 2024
a0cb969
refactor: adding abstract run and __call__
laserkelvin May 17, 2024
907a403
feat: added merge method for averaging method
laserkelvin May 17, 2024
d542fa3
feat: added run call for averaging strategy
laserkelvin May 17, 2024
db53652
feat: defining __all__ in multitask strategies
laserkelvin May 17, 2024
8c0ed40
refactor: adding multi task strategy interface in calculator
laserkelvin May 17, 2024
17aa627
refactor: adding multi task strategy application to calculate
laserkelvin May 17, 2024
c982773
test: added unit tests for multi task aggregations
laserkelvin May 17, 2024
e5cdb8e
test: added tests to check force output shape
laserkelvin May 17, 2024
951c124
refactor: added temporary step to ensure force key consistency
laserkelvin May 17, 2024
3bcc446
refactor: making multitask output keys refer to task class names
laserkelvin May 17, 2024
b0d161e
test: updating test to make things work
laserkelvin May 17, 2024
c7c25ff
fix: correcting graph key retrieval from batch
laserkelvin May 24, 2024
4e87c5d
refactor: writing a dedicated method for multitask ase inference
laserkelvin May 24, 2024
2f8894c
fix: correcting graph get retrieval
laserkelvin May 24, 2024
6e54dad
fix: patches input grad toggling based on incoming batch
laserkelvin May 24, 2024
22c5f36
script: added pretrained example from multitask
laserkelvin May 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions matsciml/interfaces/ase/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
MultiTaskLitModule,
)
from matsciml.datasets.transforms.base import AbstractDataTransform
from matsciml.interfaces.ase import multitask as mt

__all__ = ["MatSciMLCalculator"]

Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
atoms: Atoms | None = None,
directory=".",
conversion_factor: float | dict[str, float] = 1.0,
multitask_strategy: str | Callable | mt.AbstractStrategy = "AverageTasks",
**kwargs,
):
"""
Expand Down Expand Up @@ -172,6 +174,14 @@ def __init__(
self.task_module = task_module
self.transforms = transforms
self.conversion_factor = conversion_factor
if isinstance(multitask_strategy, str):
cls_name = getattr(mt, multitask_strategy, None)
if cls_name is None:
raise NameError(
f"Invalid multitask strategy name; supported strategies are {mt.__all__}"
)
multitask_strategy = cls_name()
self.multitask_strategy = multitask_strategy

@property
def conversion_factor(self) -> dict[str, float]:
Expand Down Expand Up @@ -238,20 +248,26 @@ def calculate(
# get into format ready for matsciml model
data_dict = self._format_pipeline(atoms)
# run the data structure through the model
output = self.task_module(data_dict)
# add outputs to self.results as expected by ase
if "energy" in output:
self.results["energy"] = output["energy"].detach().item()
if "force" in output:
self.results["forces"] = output["force"].detach().numpy()
if "stress" in output:
self.results["stress"] = output["stress"].detach().numpy()
if "dipole" in output:
self.results["dipole"] = output["dipole"].detach().numpy()
if len(self.results) == 0:
raise RuntimeError(
f"No expected properties were written. Output dict: {output}"
)
if isinstance(self.task_module, MultiTaskLitModule):
output = self.task_module.ase_calculate(data_dict)
# use a more complicated parser for multitasks
results = self.multitask_strategy(output, self.task_module)
self.results = results
else:
output = self.task_module(data_dict)
# add outputs to self.results as expected by ase
if "energy" in output:
self.results["energy"] = output["energy"].detach().item()
if "force" in output:
self.results["forces"] = output["force"].detach().numpy()
if "stress" in output:
self.results["stress"] = output["stress"].detach().numpy()
if "dipole" in output:
self.results["dipole"] = output["dipole"].detach().numpy()
if len(self.results) == 0:
raise RuntimeError(
f"No expected properties were written. Output dict: {output}"
)
# perform optional unit conversions
for key, value in self.conversion_factor.items():
if key in self.results:
Expand Down
146 changes: 146 additions & 0 deletions matsciml/interfaces/ase/multitask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from __future__ import annotations

from abc import abstractmethod, ABC

import torch
import numpy as np

from matsciml.models.base import (
MultiTaskLitModule,
)
from matsciml.common.types import DataDict


__task_property_mapping__ = {
"ScalarRegressionTask": ["energy", "dipole"],
"ForceRegressionTask": ["energy", "force"],
"GradFreeForceRegressionTask": ["force"],
}


__all__ = ["AverageTasks"]


class AbstractStrategy(ABC):
@abstractmethod
def merge_outputs(
self,
outputs: dict[str, dict[str, float | torch.Tensor]]
| dict[str, list[float | torch.Tensor]],
*args,
**kwargs,
) -> dict[str, float | np.ndarray]: ...

def parse_outputs(
self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs
) -> tuple[
dict[str, dict[str, float | torch.Tensor]],
dict[str, list[float | torch.Tensor]],
]:
"""
Map the task results into their appropriate fields.

Expected output looks like:
{"IS2REDataset": {"energy": ..., "forces": ...}, ...}

Parameters
----------
output_dict : DataDict
Multitask/multidata output from the ``MultiTaskLitModule``
forward pass.
task : MultiTaskLitModule
Instance of the task module. This allows access to the
``task.task_map``, which tells us which dataset/subtask
is mapped together.

Returns
-------
dict[str, dict[str, float | torch.Tensor]]
Dictionary mapping of results per dataset. The subdicts
correspond to the extracted outputs, per subtask (e.g.
energy/force from the IS2REDataset head).
dict[str, list[float | torch.Tensor]]
For convenience, this provides the same data without
differentiating between datasets, and instead, sorts
them by the property name (e.g. {"energy": [...]}).

Raises
------
RuntimeError:
When no subresults are returned for a dataset that is
expected to have something on the basis that a task
_should_ produce something, e.g. ``ForceRegressionTask``
should yield energy/force, and if it doesn't produce
anything, something is wrong.
"""
results = {}
per_key_results = {}
# loop over the task map
for dset_name in task.task_map.keys():
for subtask_name, subtask in task.task_map[dset_name].items():
sub_results = {}
pos_fields = __task_property_mapping__.get(subtask_name, None)
if pos_fields is None:
continue
else:
for key in pos_fields:
output = output_dict[dset_name][subtask_name].get(key, None)
# this means the task _can_ output the key but was
# not included in the actual training task keys
if output is None:
continue
if isinstance(output, torch.Tensor):
output = output.detach()
if key == "energy":
# squeeze is applied just in case we have too many
# extra dimensions
output = output.squeeze().item()
sub_results[key] = output
# add to per_key_results as another sorting
if key not in per_key_results:
per_key_results[key] = []
per_key_results[key].append(output)
if len(sub_results) == 0:
raise RuntimeError(
f"Expected {subtask_name} to have {pos_fields} but got nothing."
)
results[dset_name] = sub_results
return results, per_key_results

@abstractmethod
def run(self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs): ...

def __call__(
self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs
) -> dict[str, float | np.ndarray]:
aggregated_results = self.run(output_dict, task, *args, **kwargs)
# TODO: homogenize keys so we don't have to do stuff like this :P
if "force" in aggregated_results:
aggregated_results["forces"] = aggregated_results["force"]
return aggregated_results


class AverageTasks(AbstractStrategy):
def merge_outputs(
self, outputs: dict[str, list[float | torch.Tensor]], *args, **kwargs
) -> dict[str, float | np.ndarray]:
joined_results = {}
for key, results in outputs.items():
if isinstance(results[0], float):
merged_results = sum(results) / len(results)
elif isinstance(results[0], torch.Tensor):
results = torch.stack(results, dim=0)
merged_results = results.mean(dim=0).numpy()
else:
raise TypeError(
f"Only floats and tensors are supported for merging; got {type(results[0])} for key {key}."
)
joined_results[key] = merged_results
return joined_results

def run(
self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs
) -> dict[str, float | np.ndarray]:
_, per_key_results = self.parse_outputs(output_dict, task)
aggregated_results = self.merge_outputs(per_key_results)
return aggregated_results
175 changes: 175 additions & 0 deletions matsciml/interfaces/ase/tests/test_multi_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from __future__ import annotations

import pytest
import torch
import numpy as np
from ase import Atoms, units
from ase.md import VelocityVerlet

from matsciml.models.pyg import EGNN
from matsciml.models.base import (
MultiTaskLitModule,
ScalarRegressionTask,
ForceRegressionTask,
)
from matsciml.datasets.transforms import (
PeriodicPropertiesTransform,
PointCloudToGraphTransform,
)
from matsciml.interfaces.ase import multitask as mt
from matsciml.interfaces.ase import MatSciMLCalculator


@pytest.fixture
def test_pbc() -> Atoms:
pos = np.random.normal(0.0, 1.0, size=(16, 3)) * 10.0
atomic_numbers = np.random.randint(1, 100, size=(16,))
cell = np.eye(3).astype(float)
return Atoms(numbers=atomic_numbers, positions=pos, cell=cell)


@pytest.fixture
def pbc_transform() -> list:
return [PeriodicPropertiesTransform(6.0, True), PointCloudToGraphTransform("pyg")]


@pytest.fixture
def egnn_args():
return {"hidden_dim": 32, "output_dim": 32}


@pytest.fixture
def single_data_multi_task_combo(egnn_args):
output = {
"IS2REDataset": {
"ScalarRegressionTask": {"energy": torch.rand(1, 1)},
"ForceRegressionTask": {
"energy": torch.rand(1, 1),
"force": torch.rand(32, 3),
},
}
}
task = MultiTaskLitModule(
(
"IS2REDataset",
ScalarRegressionTask(
encoder_class=EGNN, encoder_kwargs=egnn_args, task_keys=["energy"]
),
),
(
"IS2REDataset",
ForceRegressionTask(
encoder_class=EGNN,
encoder_kwargs=egnn_args,
output_kwargs={"lazy": False, "input_dim": 32},
),
),
)
return output, task


@pytest.fixture
def multi_data_multi_task_combo(egnn_args):
output = {
"IS2REDataset": {
"ScalarRegressionTask": {"energy": torch.rand(1, 1)},
"ForceRegressionTask": {
"energy": torch.rand(1, 1),
"force": torch.rand(32, 3),
},
},
"S2EFDataset": {
"ForceRegressionTask": {
"energy": torch.rand(1, 1),
"force": torch.rand(32, 3),
}
},
"AlexandriaDataset": {
"ForceRegressionTask": {
"energy": torch.rand(1, 1),
"force": torch.rand(32, 3),
}
},
}
task = MultiTaskLitModule(
(
"IS2REDataset",
ScalarRegressionTask(
encoder_class=EGNN,
encoder_kwargs=egnn_args,
task_keys=["energy"],
output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32},
),
),
(
"IS2REDataset",
ForceRegressionTask(
encoder_class=EGNN,
encoder_kwargs=egnn_args,
output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32},
),
),
(
"S2EFDataset",
ForceRegressionTask(
encoder_class=EGNN,
encoder_kwargs=egnn_args,
output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32},
),
),
(
"AlexandriaDataset",
ForceRegressionTask(
encoder_class=EGNN,
encoder_kwargs=egnn_args,
output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32},
),
),
)
return output, task


def test_average_single_data(single_data_multi_task_combo):
# unpack the fixtrure
output, task = single_data_multi_task_combo
strat = mt.AverageTasks()
# test the parsing
_, parsed_output = strat.parse_outputs(output, task)
agg_results = strat.merge_outputs(parsed_output)
end = strat(output, task)
assert end
assert agg_results
for key in ["energy", "forces"]:
assert key in end, f"{key} was missing from agg results"
assert end["forces"].shape == (32, 3)


def test_average_multi_data(multi_data_multi_task_combo):
# unpack the fixtrure
output, task = multi_data_multi_task_combo
strat = mt.AverageTasks()
# test the parsing
_, parsed_output = strat.parse_outputs(output, task)
agg_results = strat.merge_outputs(parsed_output)
end = strat(output, task)
assert end
assert agg_results
for key in ["energy", "forces"]:
assert key in end, f"{key} was missing from agg results"
assert end["forces"].shape == (32, 3)


def test_calc_multi_data(
multi_data_multi_task_combo, test_pbc: Atoms, pbc_transform: list
):
output, task = multi_data_multi_task_combo
strat = mt.AverageTasks()
calc = MatSciMLCalculator(task, multitask_strategy=strat, transforms=pbc_transform)
atoms = test_pbc.copy()
atoms.calc = calc
energy = atoms.get_potential_energy()
assert np.isfinite(energy)
forces = atoms.get_forces()
assert np.isfinite(forces).all()
dyn = VelocityVerlet(atoms, timestep=5 * units.fs, logfile="md.log")
dyn.run(3)
Loading
Loading