Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
chiang-yuan committed Dec 22, 2024
1 parent 17bb201 commit a490e43
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 9 deletions.
11 changes: 6 additions & 5 deletions examples/mof/widom-insertion.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,14 @@
}
],
"source": [
"from dask.distributed import Client\n",
"from dask_jobqueue import SLURMCluster\n",
"from prefect_dask import DaskTaskRunner\n",
"\n",
"from ase.build import molecule\n",
"from mlip_arena.models import MLIPEnum\n",
"from mlip_arena.tasks.mof.input import get_atoms_from_db\n",
"from mlip_arena.tasks.mof.flow import widom_insertion\n",
"from tqdm.auto import tqdm\n",
"from mlip_arena.tasks.mof.flow import run as MOF\n",
"\n",
"from prefect import flow\n",
"\n",
"@flow\n",
"def benchmark_test():\n",
Expand Down Expand Up @@ -174,7 +175,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.11.8"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
Expand Down
16 changes: 16 additions & 0 deletions mlip_arena/tasks/mof/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
import warnings

# Locate the LICENSE file
license_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "LICENSE")

if os.path.exists(license_path):
try:
with open(license_path, "r") as license_file:
license_content = license_file.read()
warnings.warn(
f"LICENSE content:\n{license_content}",
category=UserWarning
)
except Exception as e:
pass
26 changes: 22 additions & 4 deletions mlip_arena/tasks/mof/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

from collections import defaultdict
from pathlib import Path
from typing import IO, Any, Optional
from typing import IO, Any

import numpy as np
from prefect import task, flow
from prefect import flow, task
from prefect.cache_policies import INPUTS, TASK_SOURCE
from prefect.logging import get_run_logger
from prefect.runtime import task_run
Expand All @@ -23,6 +23,7 @@

from ase import Atoms, units
from ase.atoms import Atoms
from ase.build import molecule
from ase.filters import Filter
from ase.io.trajectory import Trajectory, TrajectoryWriter
from ase.optimize.optimize import Optimizer
Expand All @@ -31,6 +32,7 @@
from mlip_arena.tasks.utils import get_calculator

from .grid import get_accessible_positions
from .input import get_atoms_from_db


def add_molecule(gas: Atoms, rotate: bool = True, translate: tuple = None) -> Atoms:
Expand Down Expand Up @@ -121,8 +123,6 @@ def widom_insertion(
optimizer_kwargs: dict | None = None,
filter: Filter | str | None = "FrechetCell",
filter_kwargs: dict | None = None,
time_step: float | None = None, # fs
total_time: float = 1000, # fs
temperature: float = 300,
init_structure_optimize: bool = True,
init_gas_optimize: bool = True,
Expand Down Expand Up @@ -335,3 +335,21 @@ def widom_insertion(
results["heat_of_adsorption"].append(qst)
# self.log_results(results)
return results


@flow
def run(
db_path: Path | str = "mof.db",
):
states = []
for model in MLIPEnum:
for atoms in tqdm(get_atoms_from_db(db_path)):
state = widom_insertion.submit(
atoms,
molecule("CO2"),
calculator_name=model.name,
return_state=True
)
states.append(state)

return [s.result(raise_on_failture=False) for s in states if s.is_completed()]
25 changes: 25 additions & 0 deletions tests/test_mof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
from mlip_arena.models import MLIPEnum
from mlip_arena.tasks.mof.flow import widom_insertion
from mlip_arena.tasks.mof.input import get_atoms_from_db
from prefect.testing.utilities import prefect_test_harness

from ase.build import molecule


@pytest.mark.parametrize("model", [MLIPEnum["MACE-MP(M)"]])
def test_widom_insertion(model: MLIPEnum):
atoms = get_atoms_from_db("mof.db")
with prefect_test_harness():
result = widom_insertion.with_options(
refresh_cache=True,
)(
structure=atoms,
gas=molecule("CO2"),
calculator_name=model.name,
num_insertions=10,
)
assert isinstance(result, dict)
assert isinstance(result["henry_coefficient"][0], float)
assert isinstance(result["averaged_interaction_energy"][0], float)
assert isinstance(result["heat_of_adsorption"][0], float)

0 comments on commit a490e43

Please sign in to comment.