From dcaa0ae5287732955b4ce9123aeb5d17b00cfd4c Mon Sep 17 00:00:00 2001 From: Hongyu Yu <74477906+Hongyu-yu@users.noreply.github.com> Date: Thu, 5 Sep 2024 09:24:33 +0800 Subject: [PATCH 01/13] Fix bug about undefined swa Bug will come out when swa is not used at the end of training. ``` mace/tools/train.py", line 262, in train if patience_counter >= patience and epoch < swa.start: AttributeError: 'NoneType' object has no attribute 'start' ``` --- mace/tools/train.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index b38bce16..5af25456 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -259,16 +259,17 @@ def train( if valid_loss >= lowest_loss: patience_counter += 1 - if patience_counter >= patience and epoch < swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" - ) - epoch = swa.start - elif patience_counter >= patience and epoch >= swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement" - ) - break + if patience_counter >= patience: + if swa is not None and epoch < swa.start: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" + ) + epoch = swa.start + else: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement" + ) + break if save_all_checkpoints: param_context = ( ema.average_parameters() From 7a475fbed5f7e90bb9ec7612015b7909550d92c6 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 12 Sep 2024 10:28:55 +0100 Subject: [PATCH 02/13] remove e3nn version --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 13d55161..4d5419d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,7 +15,7 @@ packages = find: python_requires = >=3.7 install_requires = torch>=1.12 - e3nn==0.4.4 + e3nn numpy<2.0 opt_einsum ase From a2d3dd7a4707f583adc429e49e6c7d6ddaaddc1a Mon Sep 17 00:00:00 2001 From: Tamas K Stenczel Date: Wed, 18 Sep 2024 14:37:49 +0200 Subject: [PATCH 03/13] specify map_location in mace_create_lammps_model --- mace/cli/create_lammps_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 3f647906..2530937e 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -54,7 +54,10 @@ def select_head(model): def main(): args = parse_args() model_path = args.model_path # takes model name as command-line input - model = torch.load(model_path) + model = torch.load( + model_path, + map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) model = model.double().to("cpu") if args.head is None: From 3610868083b34284d185d68648439d0dd7b21f46 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 20 Sep 2024 15:05:05 -0400 Subject: [PATCH 04/13] split download_mace_mp_checkpoint out of mace_mp to only download and return path to checkpoint for use with openmmml.MLPotential --- mace/calculators/foundations_models.py | 145 +++++++++++++------------ 1 file changed, 77 insertions(+), 68 deletions(-) diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index 6d29cb04..cecdb32b 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -15,6 +15,51 @@ ) +def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str: + """ + Downloads or locates the MACE-MP checkpoint file. + + Args: + model (str, optional): Path to the model or size specification. + Defaults to None which uses the medium model. + + Returns: + str: Path to the downloaded (or cached, if previously loaded) checkpoint file. + """ + if model in (None, "medium") and os.path.isfile(local_model_path): + return local_model_path + + urls = { + "small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", + "medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", + "large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", + } + + checkpoint_url = ( + urls.get(model, urls["medium"]) + if model in (None, "small", "medium", "large") + else model + ) + + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = "".join( + c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" + ) + cached_model_path = f"{cache_dir}/{checkpoint_url_name}" + + if not os.path.isfile(cached_model_path): + os.makedirs(cache_dir, exist_ok=True) + print(f"Downloading MACE model from {checkpoint_url!r}") + _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Model download failed, please check the URL {checkpoint_url}" + ) + print(f"Cached MACE model to {cached_model_path}") + + return cached_model_path + + def mace_mp( model: Union[str, Path] = None, device: str = "", @@ -42,59 +87,22 @@ def mace_mp( model (str, optional): Path to the model. Defaults to None which first checks for a local model and then downloads the default model from figshare. Specify "small", "medium" or "large" to download a smaller or larger model from figshare. - device (str, optional): Device to use for the model. Defaults to "cuda". + device (str, optional): Device to use for the model. Defaults to "cuda" if available. default_dtype (str, optional): Default dtype for the model. Defaults to "float32". dispersion (bool, optional): Whether to use D3 dispersion corrections. Defaults to False. damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ). dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections. - dispersion_cutoff (float, optional): Cutoff radius in Bhor for D3 dispersion corrections. + dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections. **kwargs: Passed to MACECalculator and TorchDFTD3Calculator. Returns: MACECalculator: trained on the MPtrj dataset (unless model otherwise specified). """ - if model in (None, "medium") and os.path.isfile(local_model_path): - model = local_model_path - print( - f"Using local medium Materials Project MACE model for MACECalculator {model}" - ) - elif model in (None, "small", "medium", "large") or str(model).startswith("https:"): - try: - # checkpoints release: https://github.com/ACEsuit/mace-mp/releases/tag/mace_mp_0 - urls = dict( - small="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", # 2023-12-10-mace-128-L0_energy_epoch-249.model - medium="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", # 2023-12-03-mace-128-L1_epoch-199.model - large="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", # MACE_MPtrj_2022.9.model - ) - checkpoint_url = ( - urls.get(model, urls["medium"]) - if model in (None, "small", "medium", "large") - else model - ) - cache_dir = os.path.expanduser("~/.cache/mace") - checkpoint_url_name = "".join( - c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" - ) - cached_model_path = f"{cache_dir}/{checkpoint_url_name}" - if not os.path.isfile(cached_model_path): - os.makedirs(cache_dir, exist_ok=True) - # download and save to disk - print(f"Downloading MACE model from {checkpoint_url!r}") - _, http_msg = urllib.request.urlretrieve( - checkpoint_url, cached_model_path - ) - if "Content-Type: text/html" in http_msg: - raise RuntimeError( - f"Model download failed, please check the URL {checkpoint_url}" - ) - print(f"Cached MACE model to {cached_model_path}") - model = cached_model_path - msg = f"Using Materials Project MACE for MACECalculator with {model}" - print(msg) - except Exception as exc: - raise RuntimeError( - "Model download failed and no local model found" - ) from exc + try: + model_path = download_mace_mp_checkpoint(model) + print(f"Using Materials Project MACE for MACECalculator with {model_path}") + except Exception as exc: + raise RuntimeError("Model download failed and no local model found") from exc device = device or ("cuda" if torch.cuda.is_available() else "cpu") if default_dtype == "float64": @@ -105,32 +113,33 @@ def mace_mp( print( "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." ) + mace_calc = MACECalculator( - model_paths=model, device=device, default_dtype=default_dtype, **kwargs + model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs ) - d3_calc = None - if dispersion: - gh_url = "https://github.com/pfnet-research/torch-dftd" - try: - from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator - except ImportError as exc: - raise RuntimeError( - f"Please install torch-dftd to use dispersion corrections (see {gh_url} from {exc})" - ) from exc - print( - f"Using TorchDFTD3Calculator for D3 dispersion corrections (see {gh_url})" - ) - dtype = torch.float32 if default_dtype == "float32" else torch.float64 - d3_calc = TorchDFTD3Calculator( - device=device, - damping=damping, - dtype=dtype, - xc=dispersion_xc, - cutoff=dispersion_cutoff, - **kwargs, - ) - calc = mace_calc if not dispersion else SumCalculator([mace_calc, d3_calc]) - return calc + + if not dispersion: + return mace_calc + + try: + from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator + except ImportError as exc: + raise RuntimeError( + "Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)" + ) from exc + + print("Using TorchDFTD3Calculator for D3 dispersion corrections") + dtype = torch.float32 if default_dtype == "float32" else torch.float64 + d3_calc = TorchDFTD3Calculator( + device=device, + damping=damping, + dtype=dtype, + xc=dispersion_xc, + cutoff=dispersion_cutoff, + **kwargs, + ) + + return SumCalculator([mace_calc, d3_calc]) def mace_off( From a2e68ed3e66fbf36ace715825b49c27b002c26c9 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Sat, 21 Sep 2024 18:17:40 +0100 Subject: [PATCH 05/13] remove pylint too many positional argument --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ae4223d7..525edc7b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,6 +41,7 @@ repos: '--disable=missing-class-docstring', '--disable=missing-function-docstring', '--disable=too-many-arguments', + '--diable=too-many-positional-arguments', '--disable=too-many-locals', '--disable=not-callable', '--disable=logging-fstring-interpolation', From 7469e118f7d8c85bc0b6fd5f15990eaeeab74434 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Sat, 21 Sep 2024 18:21:39 +0100 Subject: [PATCH 06/13] fix the pre commit config --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 525edc7b..d78624bb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: '--disable=missing-class-docstring', '--disable=missing-function-docstring', '--disable=too-many-arguments', - '--diable=too-many-positional-arguments', + '--disable=too-many-positional-arguments', '--disable=too-many-locals', '--disable=not-callable', '--disable=logging-fstring-interpolation', From 73a3bf63f95ec90ac61c1c6c3baebe1c63341cc0 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:53:24 +0100 Subject: [PATCH 07/13] fix test multi reference --- tests/test_foundations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_foundations.py b/tests/test_foundations.py index b1724629..fa35f8b9 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -158,6 +158,7 @@ def test_multi_reference(): forces_loaded = model_loaded(batch)["forces"] calc_foundation = mace_mp(device="cpu", default_dtype="float64") atoms = molecule("H2COH") + atoms.info["head"] = "MP2" atoms.calc = calc_foundation forces = atoms.get_forces() assert np.allclose( From 15134fc038c0ba44c458ce55917cd3f783e75165 Mon Sep 17 00:00:00 2001 From: CompRhys Date: Wed, 25 Sep 2024 13:43:28 -0400 Subject: [PATCH 08/13] fea: allow the model module to be pass rather than just the checkpoint path. --- mace/calculators/foundations_models.py | 4 +- mace/calculators/mace.py | 93 +++++++++++++++++--------- mace/cli/active_learning_md.py | 4 +- mace/tools/scripts_utils.py | 7 +- tests/test_calculator.py | 24 +++++-- tests/test_run_train.py | 12 ++-- 6 files changed, 96 insertions(+), 48 deletions(-) diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index 6d29cb04..96c5007b 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -227,4 +227,6 @@ def mace_anicc( print( "Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322" ) - return MACECalculator(model_path, device=device, default_dtype="float64") + return MACECalculator( + model_paths=model_path, device=device, default_dtype="float64" + ) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 292b114b..c74bfd24 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -18,7 +18,7 @@ from mace.modules.utils import extract_invariant from mace.tools import torch_geometric, torch_tools, utils from mace.tools.compile import prepare -from mace.tools.scripts_utils import extract_load +from mace.tools.scripts_utils import extract_model def get_model_dtype(model: torch.nn.Module) -> torch.dtype: @@ -49,8 +49,9 @@ class MACECalculator(Calculator): def __init__( self, - model_paths: Union[list, str], - device: str, + model_paths: Union[list, str] | None = None, + device: str | None = None, + models: Union[list[torch.nn.Module], torch.nn.Module] | None = None, energy_units_to_eV: float = 1.0, length_units_to_A: float = 1.0, default_dtype="", @@ -61,6 +62,21 @@ def __init__( **kwargs, ): Calculator.__init__(self, **kwargs) + + if "model_path" in kwargs: + if model_paths is None: + print("model_path argument deprecated, use model_paths") + model_paths = kwargs["model_path"] + else: + raise ValueError( + "both 'model_path' and 'model_paths' argument give, please only pass model_paths" + ) + + if (model_paths is None) == (models is None): + raise ValueError( + "Exactly one of 'model_paths' or 'models' must be provided" + ) + self.results = {} self.model_type = model_type @@ -89,53 +105,70 @@ def __init__( f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported" ) - if "model_path" in kwargs: - print("model_path argument deprecated, use model_paths") - model_paths = kwargs["model_path"] - - if isinstance(model_paths, str): - # Find all models that satisfy the wildcard (e.g. mace_model_*.pt) - model_paths_glob = glob(model_paths) - if len(model_paths_glob) == 0: - raise ValueError(f"Couldn't find MACE model files: {model_paths}") - model_paths = model_paths_glob - elif isinstance(model_paths, Path): - model_paths = [model_paths] - if len(model_paths) == 0: - raise ValueError("No mace file names supplied") - self.num_models = len(model_paths) - if len(model_paths) > 1: - print(f"Running committee mace with {len(model_paths)} models") + if model_paths is not None: + if isinstance(model_paths, str): + # Find all models that satisfy the wildcard (e.g. mace_model_*.pt) + model_paths_glob = glob(model_paths) + + if len(model_paths_glob) == 0: + raise ValueError(f"Couldn't find MACE model files: {model_paths}") + + model_paths = model_paths_glob + elif isinstance(model_paths, Path): + model_paths = [model_paths] + + if len(model_paths) == 0: + raise ValueError("No mace file names supplied") + self.num_models = len(model_paths) + + # Load models from files + self.models = [ + torch.load(f=model_path, map_location=device) + for model_path in model_paths + ] + + elif models is not None: + if not isinstance(models, list): + models = [models] + + if len(models) == 0: + raise ValueError("No models supplied") + + self.models = models + self.num_models = len(models) + + if self.num_models > 1: + print(f"Running committee mace with {self.num_models} models") + if model_type in ["MACE", "EnergyDipoleMACE"]: self.implemented_properties.extend( ["energies", "energy_var", "forces_comm", "stress_var"] ) elif model_type == "DipoleMACE": self.implemented_properties.extend(["dipole_var"]) + if compile_mode is not None: print(f"Torch compile is enabled with mode: {compile_mode}") self.models = [ torch.compile( - prepare(extract_load)(f=model_path, map_location=device), + prepare(extract_model)(model=model, map_location=device), mode=compile_mode, fullgraph=fullgraph, ) - for model_path in model_paths + for model in models ] self.use_compile = True else: - self.models = [ - torch.load(f=model_path, map_location=device) - for model_path in model_paths - ] self.use_compile = False + + # Ensure all models are on the same device for model in self.models: - model.to(device) # shouldn't be necessary but seems to help with GPU + model.to(device) + r_maxs = [model.r_max.cpu() for model in self.models] r_maxs = np.array(r_maxs) - assert np.all( - r_maxs == r_maxs[0] - ), "committee r_max are not all the same {' '.join(r_maxs)}" + if not np.all(r_maxs == r_maxs[0]): + raise ValueError(f"committee r_max are not all the same {' '.join(r_maxs)}") self.r_max = float(r_maxs[0]) self.device = torch_tools.init_device(device) diff --git a/mace/cli/active_learning_md.py b/mace/cli/active_learning_md.py index 648a30b2..a26be698 100644 --- a/mace/cli/active_learning_md.py +++ b/mace/cli/active_learning_md.py @@ -149,8 +149,8 @@ def run(args: argparse.Namespace) -> None: atoms_index = args.config_index mace_calc = MACECalculator( - mace_fname, - args.device, + model_paths=mace_fname, + device=args.device, default_dtype=args.default_dtype, ) diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index f44390a6..bcf223d4 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -218,10 +218,9 @@ def radial_to_transform(radial): def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: - model = torch.load(f=f, map_location=map_location) - model_copy = model.__class__(**extract_config_mace_model(model)) - model_copy.load_state_dict(model.state_dict()) - return model_copy.to(map_location) + return extract_model( + torch.load(f=f, map_location=map_location), map_location=map_location + ) def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module: diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 8ff87936..6590935c 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -111,7 +111,7 @@ def trained_model_fixture(tmp_path_factory, fitting_configs): assert p.returncode == 0 - return MACECalculator(tmp_path / "MACE.model", device="cpu") + return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") @pytest.fixture(scope="module", name="trained_equivariant_model") @@ -174,7 +174,7 @@ def trained_model_equivariant_fixture(tmp_path_factory, fitting_configs): assert p.returncode == 0 - return MACECalculator(tmp_path / "MACE.model", device="cpu") + return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") @pytest.fixture(scope="module", name="trained_dipole_model") @@ -239,7 +239,7 @@ def trained_dipole_fixture(tmp_path_factory, fitting_configs): assert p.returncode == 0 return MACECalculator( - tmp_path / "MACE.model", device="cpu", model_type="DipoleMACE" + model_paths=tmp_path / "MACE.model", device="cpu", model_type="DipoleMACE" ) @@ -305,7 +305,7 @@ def trained_energy_dipole_fixture(tmp_path_factory, fitting_configs): assert p.returncode == 0 return MACECalculator( - tmp_path / "MACE.model", device="cpu", model_type="EnergyDipoleMACE" + model_paths=tmp_path / "MACE.model", device="cpu", model_type="EnergyDipoleMACE" ) @@ -374,7 +374,7 @@ def trained_committee_fixture(tmp_path_factory, fitting_configs): _model_paths.append(tmp_path / f"MACE{seed}.model") - return MACECalculator(_model_paths, device="cpu") + return MACECalculator(model_paths=_model_paths, device="cpu") def test_calculator_node_energy(fitting_configs, trained_model): @@ -432,6 +432,20 @@ def test_calculator_committee(fitting_configs, trained_committee): assert forces_var.shape == at.calc.results["forces"].shape +def test_calculator_from_model(fitting_configs, trained_committee): + # test single model + test_calculator_forces( + fitting_configs, + trained_model=MACECalculator(models=trained_committee.models[0], device="cpu"), + ) + + # test committee model + test_calculator_committee( + fitting_configs, + trained_committee=MACECalculator(models=trained_committee.models, device="cpu"), + ) + + def test_calculator_dipole(fitting_configs, trained_dipole_model): at = fitting_configs[2].copy() at.calc = trained_dipole_model diff --git a/tests/test_run_train.py b/tests/test_run_train.py index fe6c8c46..015c65a7 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -100,7 +100,7 @@ def test_run_train(tmp_path, fitting_configs): p = subprocess.run(cmd.split(), env=run_env, check=True) assert p.returncode == 0 - calc = MACECalculator(tmp_path / "MACE.model", device="cpu") + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") Es = [] for at in fitting_configs: @@ -171,7 +171,7 @@ def test_run_train_missing_data(tmp_path, fitting_configs): p = subprocess.run(cmd.split(), env=run_env, check=True) assert p.returncode == 0 - calc = MACECalculator(tmp_path / "MACE.model", device="cpu") + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") Es = [] for at in fitting_configs: @@ -242,7 +242,7 @@ def test_run_train_no_stress(tmp_path, fitting_configs): p = subprocess.run(cmd.split(), env=run_env, check=True) assert p.returncode == 0 - calc = MACECalculator(tmp_path / "MACE.model", device="cpu") + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") Es = [] for at in fitting_configs: @@ -349,7 +349,7 @@ def test_run_train_multihead(tmp_path, fitting_configs): assert p.returncode == 0 calc = MACECalculator( - tmp_path / "MACE.model", device="cpu", default_dtype="float64" + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" ) Es = [] @@ -427,7 +427,7 @@ def test_run_train_foundation(tmp_path, fitting_configs): assert p.returncode == 0 calc = MACECalculator( - tmp_path / "MACE.model", device="cpu", default_dtype="float64" + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" ) Es = [] @@ -536,7 +536,7 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): assert p.returncode == 0 calc = MACECalculator( - tmp_path / "MACE.model", device="cpu", default_dtype="float64" + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" ) Es = [] From 6719849e754d81e715ab6712e008d6d637b938ac Mon Sep 17 00:00:00 2001 From: CompRhys Date: Wed, 25 Sep 2024 16:08:42 -0400 Subject: [PATCH 09/13] doc: improve the warning messages and fix typo --- mace/calculators/mace.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index c74bfd24..81a8c986 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -5,6 +5,7 @@ ########################################################################################### +import logging from glob import glob from pathlib import Path from typing import Union @@ -64,12 +65,15 @@ def __init__( Calculator.__init__(self, **kwargs) if "model_path" in kwargs: + deprecation_message = ( + "'model_path' argument is deprecated, please use 'model_paths'" + ) if model_paths is None: - print("model_path argument deprecated, use model_paths") + logging.warning(f"{deprecation_message} in the future.") model_paths = kwargs["model_path"] else: raise ValueError( - "both 'model_path' and 'model_paths' argument give, please only pass model_paths" + f"both 'model_path' and 'model_paths' given, {deprecation_message} only." ) if (model_paths is None) == (models is None): From e4cba377e6d7fc5cc9d721a84eacf30d236bed13 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 30 Sep 2024 19:17:32 +0100 Subject: [PATCH 10/13] extend read to any ase format Co-Authored-By: bernstei --- mace/cli/run_train.py | 17 +++++++++-------- mace/tools/scripts_utils.py | 12 ++++++++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 41f9f4a4..529454c2 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -49,6 +49,7 @@ get_swa, print_git_commit, setup_wandb, + check_path_ase_read, ) from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.utils import AtomicNumberTable @@ -188,10 +189,10 @@ def run(args: argparse.Namespace) -> None: ) # Data preparation - if head_config.train_file.endswith(".xyz"): + if check_path_ase_read(head_config.train_file): if head_config.valid_file is not None: - assert head_config.valid_file.endswith( - ".xyz" + assert check_path_ase_read( + head_config.valid_file ), "valid_file if given must be same format as train_file" config_type_weights = get_config_type_weights( head_config.config_type_weights @@ -221,7 +222,7 @@ def run(args: argparse.Namespace) -> None: ) head_configs.append(head_config) - if all(head_config.train_file.endswith(".xyz") for head_config in head_configs): + if all(check_path_ase_read(head_config.train_file) for head_config in head_configs): size_collections_train = sum( len(head_config.collections.train) for head_config in head_configs ) @@ -313,7 +314,7 @@ def run(args: argparse.Namespace) -> None: # yapf: disable for head_config in head_configs: if head_config.atomic_numbers is None: - assert head_config.train_file.endswith(".xyz"), "Must specify atomic_numbers when using .h5 train_file input" + assert check_path_ase_read(head_config.train_file), "Must specify atomic_numbers when using .h5 train_file input" z_table_head = tools.get_atomic_number_table_from_zs( z for configs in (head_config.collections.train, head_config.collections.valid) @@ -343,7 +344,7 @@ def run(args: argparse.Namespace) -> None: atomic_energies_dict = {} for head_config in head_configs: if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: - if head_config.train_file.endswith(".xyz") and head_config.E0s.lower() != "foundation": + if check_path_ase_read(head_config.train_file) and head_config.E0s.lower() != "foundation": atomic_energies_dict[head_config.head_name] = get_atomic_energies( head_config.E0s, head_config.collections.train, head_config.z_table ) @@ -408,7 +409,7 @@ def run(args: argparse.Namespace) -> None: valid_sets = {head: [] for head in heads} train_sets = {head: [] for head in heads} for head_config in head_configs: - if head_config.train_file.endswith(".xyz"): + if check_path_ase_read(head_config.train_file): train_sets[head_config.head_name] = [ data.AtomicData.from_config( config, z_table=z_table, cutoff=args.r_max, heads=heads @@ -625,7 +626,7 @@ def run(args: argparse.Namespace) -> None: ) and head_configs[0].test_dir is not None: stop_first_test = True for head_config in head_configs: - if head_config.train_file.endswith(".xyz"): + if check_path_ase_read(head_config.train_file): print(head_config.test_file) for name, subset in head_config.collections.tests: print(name) diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index bcf223d4..fa310afe 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -10,6 +10,7 @@ import json import logging import os +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import numpy as np @@ -903,6 +904,17 @@ def check_folder_subfolder(folder_path): return False +def check_path_ase_read(filename: str) -> str: + filepath = Path(filename) + if filepath.is_dir(): + if len(list(filepath.glob("*.h5")) + list(filepath.glob("*.hdf5"))) == 0: + raise RuntimeError(f"Got directory {filename} with no .h5/.hdf5 files") + return False + if filepath.suffix in (".h5", ".hdf5"): + return False + return True + + def dict_to_namespace(dictionary): # Convert the dictionary into an argparse.Namespace namespace = argparse.Namespace() From 384baff4d0eb706ee073d4cf809a179c2cab6086 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 30 Sep 2024 19:21:55 +0100 Subject: [PATCH 11/13] sort import --- mace/cli/run_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 529454c2..cfbb0da7 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -34,6 +34,7 @@ ) from mace.tools.scripts_utils import ( LRScheduler, + check_path_ase_read, convert_to_json_format, create_error_table, dict_to_array, @@ -49,7 +50,6 @@ get_swa, print_git_commit, setup_wandb, - check_path_ase_read, ) from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.utils import AtomicNumberTable From e10030713440e525d7e74d17bef99b03c26aab30 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 1 Oct 2024 10:34:55 +0100 Subject: [PATCH 12/13] fix the preprocess data script --- mace/tools/arg_parser.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 046f04d6..11a6d2f3 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -735,6 +735,12 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser: default=None, required=False, ) + parser.add_argument( + "--work_dir", + help="set directory for all files and folders", + type=str, + default=".", + ) parser.add_argument( "--h5_prefix", help="Prefix for h5 files when saving", From 0ca9fedc1e58c45b1ff384e63b0e7eb5303c52bb Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 1 Oct 2024 18:12:07 +0100 Subject: [PATCH 13/13] reset e3nn version to 0.4.4 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 4d5419d0..13d55161 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,7 +15,7 @@ packages = find: python_requires = >=3.7 install_requires = torch>=1.12 - e3nn + e3nn==0.4.4 numpy<2.0 opt_einsum ase