Skip to content

Commit

Permalink
Merge pull request #224 from laserkelvin/multitask-ase-interface
Browse files Browse the repository at this point in the history
Multitask ASE interface
  • Loading branch information
laserkelvin authored May 28, 2024
2 parents f24fba3 + 22c5f36 commit 92f1600
Show file tree
Hide file tree
Showing 5 changed files with 450 additions and 22 deletions.
42 changes: 42 additions & 0 deletions examples/interfaces/ase_multitask_from_pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from ase import Atoms, units
from ase.md.verlet import VelocityVerlet

from matsciml.interfaces.ase import MatSciMLCalculator
from matsciml.interfaces.ase.multitask import AverageTasks
from matsciml.datasets.transforms import (
PeriodicPropertiesTransform,
PointCloudToGraphTransform,
)

"""
Demonstrates setting up a calculator from a pretrained
multitask/multidata module, using an averaging strategy to
merge output heads.
As an example, if we trained force regression on multiple datasets
simultaneously, we would average the outputs from each "dataset",
similar to an ensemble prediction without any special weighting.
Substitute 'model.ckpt' for the path to a checkpoint file.
"""

d = 2.9
L = 10.0

atoms = Atoms("C", positions=[[0, L / 2, L / 2]], cell=[d, L, L], pbc=[1, 0, 0])

calc = MatSciMLCalculator.from_pretrained_force_regression(
"model.ckpt",
transforms=[
PeriodicPropertiesTransform(6.0, True),
PointCloudToGraphTransform("pyg"),
],
multitask_strategy=AverageTasks(), # also can be specified as a string
)
# set the calculator to matsciml
atoms.calc = calc
# run the simulation for 100 timesteps, with 5 femtosecond timesteps
dyn = VelocityVerlet(atoms, timestep=5 * units.fs, logfile="md.log")
dyn.run(100)
44 changes: 30 additions & 14 deletions matsciml/interfaces/ase/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
MultiTaskLitModule,
)
from matsciml.datasets.transforms.base import AbstractDataTransform
from matsciml.interfaces.ase import multitask as mt

__all__ = ["MatSciMLCalculator"]

Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
atoms: Atoms | None = None,
directory=".",
conversion_factor: float | dict[str, float] = 1.0,
multitask_strategy: str | Callable | mt.AbstractStrategy = "AverageTasks",
**kwargs,
):
"""
Expand Down Expand Up @@ -172,6 +174,14 @@ def __init__(
self.task_module = task_module
self.transforms = transforms
self.conversion_factor = conversion_factor
if isinstance(multitask_strategy, str):
cls_name = getattr(mt, multitask_strategy, None)
if cls_name is None:
raise NameError(
f"Invalid multitask strategy name; supported strategies are {mt.__all__}"
)
multitask_strategy = cls_name()
self.multitask_strategy = multitask_strategy

@property
def conversion_factor(self) -> dict[str, float]:
Expand Down Expand Up @@ -238,20 +248,26 @@ def calculate(
# get into format ready for matsciml model
data_dict = self._format_pipeline(atoms)
# run the data structure through the model
output = self.task_module(data_dict)
# add outputs to self.results as expected by ase
if "energy" in output:
self.results["energy"] = output["energy"].detach().item()
if "force" in output:
self.results["forces"] = output["force"].detach().numpy()
if "stress" in output:
self.results["stress"] = output["stress"].detach().numpy()
if "dipole" in output:
self.results["dipole"] = output["dipole"].detach().numpy()
if len(self.results) == 0:
raise RuntimeError(
f"No expected properties were written. Output dict: {output}"
)
if isinstance(self.task_module, MultiTaskLitModule):
output = self.task_module.ase_calculate(data_dict)
# use a more complicated parser for multitasks
results = self.multitask_strategy(output, self.task_module)
self.results = results
else:
output = self.task_module(data_dict)
# add outputs to self.results as expected by ase
if "energy" in output:
self.results["energy"] = output["energy"].detach().item()
if "force" in output:
self.results["forces"] = output["force"].detach().numpy()
if "stress" in output:
self.results["stress"] = output["stress"].detach().numpy()
if "dipole" in output:
self.results["dipole"] = output["dipole"].detach().numpy()
if len(self.results) == 0:
raise RuntimeError(
f"No expected properties were written. Output dict: {output}"
)
# perform optional unit conversions
for key, value in self.conversion_factor.items():
if key in self.results:
Expand Down
146 changes: 146 additions & 0 deletions matsciml/interfaces/ase/multitask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from __future__ import annotations

from abc import abstractmethod, ABC

import torch
import numpy as np

from matsciml.models.base import (
MultiTaskLitModule,
)
from matsciml.common.types import DataDict


__task_property_mapping__ = {
"ScalarRegressionTask": ["energy", "dipole"],
"ForceRegressionTask": ["energy", "force"],
"GradFreeForceRegressionTask": ["force"],
}


__all__ = ["AverageTasks"]


class AbstractStrategy(ABC):
@abstractmethod
def merge_outputs(
self,
outputs: dict[str, dict[str, float | torch.Tensor]]
| dict[str, list[float | torch.Tensor]],
*args,
**kwargs,
) -> dict[str, float | np.ndarray]: ...

def parse_outputs(
self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs
) -> tuple[
dict[str, dict[str, float | torch.Tensor]],
dict[str, list[float | torch.Tensor]],
]:
"""
Map the task results into their appropriate fields.
Expected output looks like:
{"IS2REDataset": {"energy": ..., "forces": ...}, ...}
Parameters
----------
output_dict : DataDict
Multitask/multidata output from the ``MultiTaskLitModule``
forward pass.
task : MultiTaskLitModule
Instance of the task module. This allows access to the
``task.task_map``, which tells us which dataset/subtask
is mapped together.
Returns
-------
dict[str, dict[str, float | torch.Tensor]]
Dictionary mapping of results per dataset. The subdicts
correspond to the extracted outputs, per subtask (e.g.
energy/force from the IS2REDataset head).
dict[str, list[float | torch.Tensor]]
For convenience, this provides the same data without
differentiating between datasets, and instead, sorts
them by the property name (e.g. {"energy": [...]}).
Raises
------
RuntimeError:
When no subresults are returned for a dataset that is
expected to have something on the basis that a task
_should_ produce something, e.g. ``ForceRegressionTask``
should yield energy/force, and if it doesn't produce
anything, something is wrong.
"""
results = {}
per_key_results = {}
# loop over the task map
for dset_name in task.task_map.keys():
for subtask_name, subtask in task.task_map[dset_name].items():
sub_results = {}
pos_fields = __task_property_mapping__.get(subtask_name, None)
if pos_fields is None:
continue
else:
for key in pos_fields:
output = output_dict[dset_name][subtask_name].get(key, None)
# this means the task _can_ output the key but was
# not included in the actual training task keys
if output is None:
continue
if isinstance(output, torch.Tensor):
output = output.detach()
if key == "energy":
# squeeze is applied just in case we have too many
# extra dimensions
output = output.squeeze().item()
sub_results[key] = output
# add to per_key_results as another sorting
if key not in per_key_results:
per_key_results[key] = []
per_key_results[key].append(output)
if len(sub_results) == 0:
raise RuntimeError(
f"Expected {subtask_name} to have {pos_fields} but got nothing."
)
results[dset_name] = sub_results
return results, per_key_results

@abstractmethod
def run(self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs): ...

def __call__(
self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs
) -> dict[str, float | np.ndarray]:
aggregated_results = self.run(output_dict, task, *args, **kwargs)
# TODO: homogenize keys so we don't have to do stuff like this :P
if "force" in aggregated_results:
aggregated_results["forces"] = aggregated_results["force"]
return aggregated_results


class AverageTasks(AbstractStrategy):
def merge_outputs(
self, outputs: dict[str, list[float | torch.Tensor]], *args, **kwargs
) -> dict[str, float | np.ndarray]:
joined_results = {}
for key, results in outputs.items():
if isinstance(results[0], float):
merged_results = sum(results) / len(results)
elif isinstance(results[0], torch.Tensor):
results = torch.stack(results, dim=0)
merged_results = results.mean(dim=0).numpy()
else:
raise TypeError(
f"Only floats and tensors are supported for merging; got {type(results[0])} for key {key}."
)
joined_results[key] = merged_results
return joined_results

def run(
self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs
) -> dict[str, float | np.ndarray]:
_, per_key_results = self.parse_outputs(output_dict, task)
aggregated_results = self.merge_outputs(per_key_results)
return aggregated_results
Loading

0 comments on commit 92f1600

Please sign in to comment.