fix swa bug and remove e3nn fixed version #589

Merged · 18 commits · Oct 1, 2024
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -41,6 +41,7 @@ repos:
          '--disable=missing-class-docstring',
          '--disable=missing-function-docstring',
          '--disable=too-many-arguments',
          '--disable=too-many-positional-arguments',
          '--disable=too-many-locals',
          '--disable=not-callable',
          '--disable=logging-fstring-interpolation',
149 changes: 80 additions & 69 deletions mace/calculators/foundations_models.py
@@ -15,6 +15,51 @@
)


def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
    """
    Downloads or locates the MACE-MP checkpoint file.

    Args:
        model (str, optional): Path to the model or size specification.
            Defaults to None which uses the medium model.

    Returns:
        str: Path to the downloaded (or cached, if previously loaded) checkpoint file.
    """
    if model in (None, "medium") and os.path.isfile(local_model_path):
        return local_model_path

    urls = {
        "small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
        "medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
        "large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model",
    }

    checkpoint_url = (
        urls.get(model, urls["medium"])
        if model in (None, "small", "medium", "large")
        else model
    )

    cache_dir = os.path.expanduser("~/.cache/mace")
    checkpoint_url_name = "".join(
        c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
    )
    cached_model_path = f"{cache_dir}/{checkpoint_url_name}"

    if not os.path.isfile(cached_model_path):
        os.makedirs(cache_dir, exist_ok=True)
        print(f"Downloading MACE model from {checkpoint_url!r}")
        _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path)
        if "Content-Type: text/html" in http_msg:
            raise RuntimeError(
                f"Model download failed, please check the URL {checkpoint_url}"
            )
        print(f"Cached MACE model to {cached_model_path}")

    return cached_model_path
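For orientation, a minimal usage sketch of the new helper (the call sites are illustrative; the behaviour follows the code above):

```python
from mace.calculators.foundations_models import download_mace_mp_checkpoint

# Named sizes resolve to the mace-mp-0 release checkpoints and are cached
# under ~/.cache/mace; a second call returns the cached file without a download.
checkpoint_path = download_mace_mp_checkpoint("small")

# A direct URL is also accepted and cached under a sanitised file name.
checkpoint_path = download_mace_mp_checkpoint(
    "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model"
)
```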


def mace_mp(
    model: Union[str, Path] = None,
    device: str = "",
@@ -42,59 +87,22 @@ def mace_mp(
        model (str, optional): Path to the model. Defaults to None which first checks for
            a local model and then downloads the default model from figshare. Specify "small",
            "medium" or "large" to download a smaller or larger model from figshare.
-       device (str, optional): Device to use for the model. Defaults to "cuda".
+       device (str, optional): Device to use for the model. Defaults to "cuda" if available.
        default_dtype (str, optional): Default dtype for the model. Defaults to "float32".
        dispersion (bool, optional): Whether to use D3 dispersion corrections. Defaults to False.
        damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ).
        dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections.
-       dispersion_cutoff (float, optional): Cutoff radius in Bhor for D3 dispersion corrections.
+       dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections.
        **kwargs: Passed to MACECalculator and TorchDFTD3Calculator.

    Returns:
        MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
    """
-    if model in (None, "medium") and os.path.isfile(local_model_path):
-        model = local_model_path
-        print(
-            f"Using local medium Materials Project MACE model for MACECalculator {model}"
-        )
-    elif model in (None, "small", "medium", "large") or str(model).startswith("https:"):
-        try:
-            # checkpoints release: https://github.com/ACEsuit/mace-mp/releases/tag/mace_mp_0
-            urls = dict(
-                small="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",  # 2023-12-10-mace-128-L0_energy_epoch-249.model
-                medium="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",  # 2023-12-03-mace-128-L1_epoch-199.model
-                large="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model",  # MACE_MPtrj_2022.9.model
-            )
-            checkpoint_url = (
-                urls.get(model, urls["medium"])
-                if model in (None, "small", "medium", "large")
-                else model
-            )
-            cache_dir = os.path.expanduser("~/.cache/mace")
-            checkpoint_url_name = "".join(
-                c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
-            )
-            cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
-            if not os.path.isfile(cached_model_path):
-                os.makedirs(cache_dir, exist_ok=True)
-                # download and save to disk
-                print(f"Downloading MACE model from {checkpoint_url!r}")
-                _, http_msg = urllib.request.urlretrieve(
-                    checkpoint_url, cached_model_path
-                )
-                if "Content-Type: text/html" in http_msg:
-                    raise RuntimeError(
-                        f"Model download failed, please check the URL {checkpoint_url}"
-                    )
-                print(f"Cached MACE model to {cached_model_path}")
-            model = cached_model_path
-            msg = f"Using Materials Project MACE for MACECalculator with {model}"
-            print(msg)
-        except Exception as exc:
-            raise RuntimeError(
-                "Model download failed and no local model found"
-            ) from exc
+    try:
+        model_path = download_mace_mp_checkpoint(model)
+        print(f"Using Materials Project MACE for MACECalculator with {model_path}")
+    except Exception as exc:
+        raise RuntimeError("Model download failed and no local model found") from exc

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if default_dtype == "float64":
@@ -105,32 +113,33 @@ def mace_mp(
        print(
            "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
        )

    mace_calc = MACECalculator(
-       model_paths=model, device=device, default_dtype=default_dtype, **kwargs
+       model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs
    )
-    d3_calc = None
-    if dispersion:
-        gh_url = "https://github.com/pfnet-research/torch-dftd"
-        try:
-            from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
-        except ImportError as exc:
-            raise RuntimeError(
-                f"Please install torch-dftd to use dispersion corrections (see {gh_url} from {exc})"
-            ) from exc
-        print(
-            f"Using TorchDFTD3Calculator for D3 dispersion corrections (see {gh_url})"
-        )
-        dtype = torch.float32 if default_dtype == "float32" else torch.float64
-        d3_calc = TorchDFTD3Calculator(
-            device=device,
-            damping=damping,
-            dtype=dtype,
-            xc=dispersion_xc,
-            cutoff=dispersion_cutoff,
-            **kwargs,
-        )
-    calc = mace_calc if not dispersion else SumCalculator([mace_calc, d3_calc])
-    return calc
+    if not dispersion:
+        return mace_calc
+
+    try:
+        from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
+    except ImportError as exc:
+        raise RuntimeError(
+            "Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)"
+        ) from exc
+
+    print("Using TorchDFTD3Calculator for D3 dispersion corrections")
+    dtype = torch.float32 if default_dtype == "float32" else torch.float64
+    d3_calc = TorchDFTD3Calculator(
+        device=device,
+        damping=damping,
+        dtype=dtype,
+        xc=dispersion_xc,
+        cutoff=dispersion_cutoff,
+        **kwargs,
+    )
+
+    return SumCalculator([mace_calc, d3_calc])
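A hedged sketch of the refactored entry point, assuming mace-torch and torch-dftd are installed (the keyword values are illustrative):

```python
from mace.calculators import mace_mp

# Plain MACE-MP calculator: downloads (or reuses) the medium checkpoint.
calc = mace_mp(model="medium", device="cpu", default_dtype="float64")

# With dispersion=True the function now returns an ASE SumCalculator that
# combines the MACE calculator with a TorchDFTD3Calculator (D3(BJ) damping).
calc_d3 = mace_mp(model="medium", device="cpu", dispersion=True, damping="bj")
```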


def mace_off(
Expand Down Expand Up @@ -227,4 +236,6 @@ def mace_anicc(
    print(
        "Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322"
    )
-   return MACECalculator(model_path, device=device, default_dtype="float64")
+   return MACECalculator(
+       model_paths=model_path, device=device, default_dtype="float64"
+   )
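The call-site change above matches the calculator's move to keyword arguments; a minimal sketch, assuming mace_anicc is re-exported from mace.calculators as in the package __init__:

```python
from mace.calculators import mace_anicc

# Builds a MACECalculator for the ANI coupled-cluster model; the checkpoint
# path is resolved internally and forwarded as model_paths=... as shown above.
calc = mace_anicc(device="cpu")
```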
97 changes: 67 additions & 30 deletions mace/calculators/mace.py
@@ -5,6 +5,7 @@
###########################################################################################


import logging
from glob import glob
from pathlib import Path
from typing import Union
@@ -18,7 +19,7 @@
from mace.modules.utils import extract_invariant
from mace.tools import torch_geometric, torch_tools, utils
from mace.tools.compile import prepare
-from mace.tools.scripts_utils import extract_load
+from mace.tools.scripts_utils import extract_model


def get_model_dtype(model: torch.nn.Module) -> torch.dtype:
Expand Down Expand Up @@ -49,8 +50,9 @@ class MACECalculator(Calculator):

    def __init__(
        self,
-       model_paths: Union[list, str],
-       device: str,
+       model_paths: Union[list, str] | None = None,
+       device: str | None = None,
+       models: Union[list[torch.nn.Module], torch.nn.Module] | None = None,
        energy_units_to_eV: float = 1.0,
        length_units_to_A: float = 1.0,
        default_dtype="",
@@ -61,6 +63,24 @@ def __init__(
        **kwargs,
    ):
        Calculator.__init__(self, **kwargs)

if "model_path" in kwargs:
deprecation_message = (
"'model_path' argument is deprecated, please use 'model_paths'"
)
if model_paths is None:
logging.warning(f"{deprecation_message} in the future.")
model_paths = kwargs["model_path"]
else:
raise ValueError(
f"both 'model_path' and 'model_paths' given, {deprecation_message} only."
)

if (model_paths is None) == (models is None):
raise ValueError(
"Exactly one of 'model_paths' or 'models' must be provided"
)

self.results = {}

self.model_type = model_type
@@ -89,53 +109,70 @@ def __init__(
f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported"
)

if "model_path" in kwargs:
print("model_path argument deprecated, use model_paths")
model_paths = kwargs["model_path"]

if isinstance(model_paths, str):
# Find all models that satisfy the wildcard (e.g. mace_model_*.pt)
model_paths_glob = glob(model_paths)
if len(model_paths_glob) == 0:
raise ValueError(f"Couldn't find MACE model files: {model_paths}")
model_paths = model_paths_glob
elif isinstance(model_paths, Path):
model_paths = [model_paths]
if len(model_paths) == 0:
raise ValueError("No mace file names supplied")
self.num_models = len(model_paths)
if len(model_paths) > 1:
print(f"Running committee mace with {len(model_paths)} models")
if model_paths is not None:
if isinstance(model_paths, str):
# Find all models that satisfy the wildcard (e.g. mace_model_*.pt)
model_paths_glob = glob(model_paths)

if len(model_paths_glob) == 0:
raise ValueError(f"Couldn't find MACE model files: {model_paths}")

model_paths = model_paths_glob
elif isinstance(model_paths, Path):
model_paths = [model_paths]

if len(model_paths) == 0:
raise ValueError("No mace file names supplied")
self.num_models = len(model_paths)

# Load models from files
self.models = [
torch.load(f=model_path, map_location=device)
for model_path in model_paths
]

elif models is not None:
if not isinstance(models, list):
models = [models]

if len(models) == 0:
raise ValueError("No models supplied")

self.models = models
self.num_models = len(models)

if self.num_models > 1:
print(f"Running committee mace with {self.num_models} models")

        if model_type in ["MACE", "EnergyDipoleMACE"]:
            self.implemented_properties.extend(
                ["energies", "energy_var", "forces_comm", "stress_var"]
            )
        elif model_type == "DipoleMACE":
            self.implemented_properties.extend(["dipole_var"])

        if compile_mode is not None:
            print(f"Torch compile is enabled with mode: {compile_mode}")
            self.models = [
                torch.compile(
-                   prepare(extract_load)(f=model_path, map_location=device),
+                   prepare(extract_model)(model=model, map_location=device),
                    mode=compile_mode,
                    fullgraph=fullgraph,
                )
-               for model_path in model_paths
+               for model in models
            ]
            self.use_compile = True
        else:
-           self.models = [
-               torch.load(f=model_path, map_location=device)
-               for model_path in model_paths
-           ]
            self.use_compile = False

        # Ensure all models are on the same device
        for model in self.models:
-           model.to(device)  # shouldn't be necessary but seems to help with GPU
+           model.to(device)

        r_maxs = [model.r_max.cpu() for model in self.models]
        r_maxs = np.array(r_maxs)
-       assert np.all(
-           r_maxs == r_maxs[0]
-       ), "committee r_max are not all the same {' '.join(r_maxs)}"
+       if not np.all(r_maxs == r_maxs[0]):
+           raise ValueError(
+               f"committee r_max are not all the same: {' '.join(str(r) for r in r_maxs)}"
+           )
        self.r_max = float(r_maxs[0])

        self.device = torch_tools.init_device(device)
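A sketch of the two construction paths the new __init__ supports (the file names are hypothetical):

```python
import torch
from mace.calculators import MACECalculator

# From checkpoint files on disk; a string glob such as "mace_model_*.pt"
# expands to a committee. Passing both model_paths and models raises ValueError.
calc = MACECalculator(model_paths="mace_model_*.pt", device="cpu")

# From a model object already in memory (e.g. mid-training evaluation),
# which skips the torch.load path entirely.
model = torch.load("mace_model_run-0.pt", map_location="cpu")
calc_in_memory = MACECalculator(models=model, device="cpu")
```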
4 changes: 2 additions & 2 deletions mace/cli/active_learning_md.py
@@ -149,8 +149,8 @@ def run(args: argparse.Namespace) -> None:
    atoms_index = args.config_index

    mace_calc = MACECalculator(
-       mace_fname,
-       args.device,
+       model_paths=mace_fname,
+       device=args.device,
        default_dtype=args.default_dtype,
    )
5 changes: 4 additions & 1 deletion mace/cli/create_lammps_model.py
@@ -54,7 +54,10 @@ def select_head(model):
def main():
    args = parse_args()
    model_path = args.model_path  # takes model name as command-line input
-   model = torch.load(model_path)
+   model = torch.load(
+       model_path,
+       map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+   )
    model = model.double().to("cpu")

    if args.head is None:
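The map_location change above is what lets a GPU-trained checkpoint be converted on a CPU-only host; a minimal sketch of the pattern (the checkpoint name is hypothetical):

```python
import torch

# Map a possibly GPU-trained checkpoint onto whatever device is available,
# so loading does not fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("my_mace_model.pt", map_location=device)

# As in create_lammps_model.py above, the exported model is float64 on CPU.
model = model.double().to("cpu")
```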