diff --git a/mlff/cAPI/mlff_structure_relaxation.py b/mlff/cAPI/mlff_structure_relaxation.py index 6c581ea..45dbf3b 100644 --- a/mlff/cAPI/mlff_structure_relaxation.py +++ b/mlff/cAPI/mlff_structure_relaxation.py @@ -1,4 +1,5 @@ import argparse +import numpy as np import os import logging @@ -57,6 +58,8 @@ def run_relaxation(): parser.add_argument('--qn_tol', type=float, required=False, default=1e-4) parser.add_argument('--qn_max_steps', type=int, required=False, default=200) + parser.add_argument('--optimizer', type=str, required=False, default='QuasiNewton') + parser.add_argument('--mic', type=str, required=False, default=None, help='Minimal image convention.') @@ -221,12 +224,14 @@ def load_start_geometry(f: str) -> Atoms: # # scales = read_json(os.path.join(ckpt_dir, 'scales.json')) - potential = mdx.MLFFPotential.create_from_ckpt_dir(ckpt_dir=ckpt_dir, dtype=_mdx_dtype) - calc = mlffCalculator(potential=potential, - capacity_multiplier=1.25, - F_to_eV_Ang=default_access(conversion_table, key=F_key, default=eV), - E_to_eV=default_access(conversion_table, key=E_key, default=eV), - ) + calc = mlffCalculator.create_from_ckpt_dir( + ckpt_dir=ckpt_dir, + capacity_multiplier=1.25, + add_energy_shift=False, + F_to_eV_Ang=default_access(conversion_table, key=F_key, default=eV), + E_to_eV=default_access(conversion_table, key=E_key, default=eV), + dtype=np.float64, + ) molecule.set_calculator(calc) @@ -234,7 +239,7 @@ def load_start_geometry(f: str) -> Atoms: from ase.io import write write(os.path.join(save_dir, 'init_structure.xyz'), molecule) # do a geometry relaxation - qn = ase_opt.LBFGS(molecule) + qn = getattr(ase_opt, args.optimizer)(molecule) converged = qn.run(qn_tol, qn_max_steps) if converged: write(os.path.join(save_dir, 'relaxed_structure.xyz'), molecule) diff --git a/mlff/md/calculator.py b/mlff/md/calculator.py index 56a4345..a0637fe 100644 --- a/mlff/md/calculator.py +++ b/mlff/md/calculator.py @@ -37,7 +37,7 @@ def create_from_ckpt_dir(cls, F_to_eV_Ang: float = 1., capacity_multiplier: float = 1.25, add_energy_shift: bool = False, - dtype: np.dtype = np.float32): + dtype: np.dtype = np.float64): mlff_potential = MLFFPotential.create_from_ckpt_dir( ckpt_dir=ckpt_dir, @@ -60,7 +60,7 @@ def __init__( F_to_eV_Ang: float = 1., capacity_multiplier: float = 1.25, calculate_stress: bool = False, - dtype: np.dtype = np.float32, + dtype: np.dtype = np.float64, *args, **kwargs ): @@ -144,7 +144,7 @@ def calculate(self, atoms=None, *args, **kwargs): output = self.calculate_fn(System(R=R, Z=z, cell=cell), neighbors=neighbors) # note different cell convention - self.results = jax.tree_map(lambda x: np.array(x), output) + self.results = jax.tree_map(lambda x: np.array(x, dtype=self.dtype), output) def to_displacement(cell):