Skip to content

Commit

Permalink
Merge pull request #602 from janosh/isolate-mace-mp-download
Browse files Browse the repository at this point in the history
Split `download_mace_mp_checkpoint` out of `mace_mp`
  • Loading branch information
ilyes319 authored Sep 26, 2024
2 parents 73a3bf6 + 5c0a344 commit ec288b2
Showing 1 changed file with 77 additions and 68 deletions.
145 changes: 77 additions & 68 deletions mace/calculators/foundations_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,51 @@
)


def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
"""
Downloads or locates the MACE-MP checkpoint file.
Args:
model (str, optional): Path to the model or size specification.
Defaults to None which uses the medium model.
Returns:
str: Path to the downloaded (or cached, if previously loaded) checkpoint file.
"""
if model in (None, "medium") and os.path.isfile(local_model_path):
return local_model_path

urls = {
"small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
"medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
"large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model",
}

checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
else model
)

cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"

if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
print(f"Downloading MACE model from {checkpoint_url!r}")
_, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path)
if "Content-Type: text/html" in http_msg:
raise RuntimeError(
f"Model download failed, please check the URL {checkpoint_url}"
)
print(f"Cached MACE model to {cached_model_path}")

return cached_model_path


def mace_mp(
model: Union[str, Path] = None,
device: str = "",
Expand Down Expand Up @@ -42,59 +87,22 @@ def mace_mp(
model (str, optional): Path to the model. Defaults to None which first checks for
a local model and then downloads the default model from figshare. Specify "small",
"medium" or "large" to download a smaller or larger model from figshare.
device (str, optional): Device to use for the model. Defaults to "cuda".
device (str, optional): Device to use for the model. Defaults to "cuda" if available.
default_dtype (str, optional): Default dtype for the model. Defaults to "float32".
dispersion (bool, optional): Whether to use D3 dispersion corrections. Defaults to False.
damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ).
dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections.
dispersion_cutoff (float, optional): Cutoff radius in Bhor for D3 dispersion corrections.
dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections.
**kwargs: Passed to MACECalculator and TorchDFTD3Calculator.
Returns:
MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
"""
if model in (None, "medium") and os.path.isfile(local_model_path):
model = local_model_path
print(
f"Using local medium Materials Project MACE model for MACECalculator {model}"
)
elif model in (None, "small", "medium", "large") or str(model).startswith("https:"):
try:
# checkpoints release: https://github.com/ACEsuit/mace-mp/releases/tag/mace_mp_0
urls = dict(
small="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", # 2023-12-10-mace-128-L0_energy_epoch-249.model
medium="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", # 2023-12-03-mace-128-L1_epoch-199.model
large="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", # MACE_MPtrj_2022.9.model
)
checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
else model
)
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
print(f"Downloading MACE model from {checkpoint_url!r}")
_, http_msg = urllib.request.urlretrieve(
checkpoint_url, cached_model_path
)
if "Content-Type: text/html" in http_msg:
raise RuntimeError(
f"Model download failed, please check the URL {checkpoint_url}"
)
print(f"Cached MACE model to {cached_model_path}")
model = cached_model_path
msg = f"Using Materials Project MACE for MACECalculator with {model}"
print(msg)
except Exception as exc:
raise RuntimeError(
"Model download failed and no local model found"
) from exc
try:
model_path = download_mace_mp_checkpoint(model)
print(f"Using Materials Project MACE for MACECalculator with {model_path}")
except Exception as exc:
raise RuntimeError("Model download failed and no local model found") from exc

device = device or ("cuda" if torch.cuda.is_available() else "cpu")
if default_dtype == "float64":
Expand All @@ -105,32 +113,33 @@ def mace_mp(
print(
"Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
)

mace_calc = MACECalculator(
model_paths=model, device=device, default_dtype=default_dtype, **kwargs
model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs
)
d3_calc = None
if dispersion:
gh_url = "https://github.com/pfnet-research/torch-dftd"
try:
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
except ImportError as exc:
raise RuntimeError(
f"Please install torch-dftd to use dispersion corrections (see {gh_url} from {exc})"
) from exc
print(
f"Using TorchDFTD3Calculator for D3 dispersion corrections (see {gh_url})"
)
dtype = torch.float32 if default_dtype == "float32" else torch.float64
d3_calc = TorchDFTD3Calculator(
device=device,
damping=damping,
dtype=dtype,
xc=dispersion_xc,
cutoff=dispersion_cutoff,
**kwargs,
)
calc = mace_calc if not dispersion else SumCalculator([mace_calc, d3_calc])
return calc

if not dispersion:
return mace_calc

try:
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
except ImportError as exc:
raise RuntimeError(
"Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)"
) from exc

print("Using TorchDFTD3Calculator for D3 dispersion corrections")
dtype = torch.float32 if default_dtype == "float32" else torch.float64
d3_calc = TorchDFTD3Calculator(
device=device,
damping=damping,
dtype=dtype,
xc=dispersion_xc,
cutoff=dispersion_cutoff,
**kwargs,
)

return SumCalculator([mace_calc, d3_calc])


def mace_off(
Expand Down

0 comments on commit ec288b2

Please sign in to comment.