more logger
chiang-yuan committed Jan 12, 2025
1 parent cb1fb61 commit 89f2c18
Showing 5 changed files with 46 additions and 37 deletions.
13 changes: 10 additions & 3 deletions mlip_arena/models/utils.py
@@ -2,6 +2,13 @@

import torch

try:
from prefect.logging import get_run_logger

logger = get_run_logger()
except (ImportError, RuntimeError):
from loguru import logger


def get_freer_device() -> torch.device:
"""Get the GPU with the most free memory, or use MPS if available.
@@ -22,16 +29,16 @@ def get_freer_device() -> torch.device:
]
free_gpu_index = mem_free.index(max(mem_free))
device = torch.device(f"cuda:{free_gpu_index}")
print(
logger.info(
f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
)
elif torch.backends.mps.is_available():
# If no CUDA GPUs are available but MPS is, use MPS
print("No GPU available. Using MPS.")
logger.info("No GPU available. Using MPS.")
device = torch.device("mps")
else:
# Fallback to CPU if neither CUDA GPUs nor MPS are available
print("No GPU or MPS available. Using CPU.")
logger.info("No GPU or MPS available. Using CPU.")
device = torch.device("cpu")

return device
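The change above swaps bare print calls for a logger that prefers Prefect's run logger and falls back to loguru when Prefect is missing or no run context is active. A minimal standalone sketch of that fallback pattern (assuming loguru is installed; outside a Prefect flow or task run, get_run_logger raises an error that is caught here):

try:
    from prefect.logging import get_run_logger

    # Only usable inside a Prefect flow or task run; otherwise this raises.
    logger = get_run_logger()
except (ImportError, RuntimeError):
    # Fall back to loguru when Prefect is unavailable or no run context exists.
    from loguru import logger

logger.info("Logging works both inside and outside Prefect runs.")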
6 changes: 4 additions & 2 deletions mlip_arena/tasks/elasticity.py
@@ -48,8 +48,6 @@
from prefect.states import State

from ase import Atoms
from ase.filters import * # type: ignore
from ase.optimize import * # type: ignore
from ase.optimize.optimize import Optimizer
from mlip_arena.models import MLIPEnum
from mlip_arena.tasks.optimize import run as OPT
@@ -81,6 +79,8 @@ def run(
atoms: Atoms,
calculator_name: str | MLIPEnum,
calculator_kwargs: dict | None = None,
dispersion: bool = False,
dispersion_kwargs: dict | None = None,
device: str | None = None,
optimizer: Optimizer | str = "BFGSLineSearch", # type: ignore
optimizer_kwargs: dict | None = None,
@@ -124,6 +124,8 @@ def run(
atoms=atoms,
calculator_name=calculator_name,
calculator_kwargs=calculator_kwargs,
dispersion=dispersion,
dispersion_kwargs=dispersion_kwargs,
device=device,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
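The elasticity task now forwards the new dispersion and dispersion_kwargs arguments to the underlying optimization. A hypothetical call sketch, not part of the diff: the model name is only a placeholder for an MLIPEnum member, the D3 keyword is an illustrative assumption, and in practice this Prefect task would typically be invoked from within a flow.

from ase.build import bulk
from mlip_arena.tasks.elasticity import run as ELASTICITY

result = ELASTICITY(
    atoms=bulk("Cu", "fcc", a=3.6),
    calculator_name="SomeModel",          # placeholder for an MLIPEnum member name
    dispersion=True,                      # new flag added in this commit
    dispersion_kwargs={"damping": "bj"},  # assumed TorchDFTD3 keyword
)
print(result)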
27 changes: 10 additions & 17 deletions mlip_arena/tasks/neb.py
@@ -39,7 +39,7 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from typing import Any, Literal

from prefect import task
from prefect.cache_policies import INPUTS, TASK_SOURCE
@@ -54,13 +54,9 @@
from ase.utils.forcecurve import fit_images
from mlip_arena.models import MLIPEnum
from mlip_arena.tasks.optimize import run as OPT
from mlip_arena.tasks.utils import get_calculator
from mlip_arena.tasks.utils import get_calculator, logger, pformat
from pymatgen.io.ase import AseAtomsAdaptor


if TYPE_CHECKING:
pass

_valid_optimizers: dict[str, Optimizer] = {
"MDMin": MDMin,
"FIRE": FIRE,
@@ -86,7 +82,7 @@ def _generate_task_run_name():
atoms = parameters["start"]
else:
raise ValueError("No images or start atoms found in parameters")

calculator_name = parameters["calculator_name"]

return f"{task_name}: {atoms.get_chemical_formula()} - {calculator_name}"
@@ -156,20 +152,17 @@ def run(
criterion = criterion or {}

optimizer_instance = optimizer(neb, trajectory=traj_file, **optimizer_kwargs) # type: ignore

logger.info(f"Using optimizer: {optimizer_instance}")
logger.info(pformat(optimizer_kwargs))
logger.info(f"Criterion: {pformat(criterion)}")
optimizer_instance.run(**criterion)

neb_tool = NEBTools(neb.images)
barrier = neb_tool.get_barrier()

forcefit = fit_images(neb.images)

images = neb.images

return {
"barrier": barrier,
"images": images,
"forcefit": forcefit,
"barrier": neb_tool.get_barrier(),
"images": neb.images,
"forcefit": fit_images(neb.images),
}


@@ -261,7 +254,7 @@ def run_from_end_points(
)
)

images = [s.to_ase_atoms() for s in path]
images = [s.to_ase_atoms(msonable=False) for s in path]

return run.with_options(
refresh_cache=not cache_subtasks,
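The return block above now inlines NEBTools and fit_images instead of holding them in intermediate variables. A small self-contained sketch of that same summary step using plain ASE (it assumes the images already have calculators attached and have been relaxed):

from ase.neb import NEBTools
from ase.utils.forcecurve import fit_images

def summarize_neb(images):
    """Collect the same quantities the task above returns."""
    neb_tool = NEBTools(images)
    return {
        "barrier": neb_tool.get_barrier(),  # (barrier, dE) tuple from ASE
        "images": images,
        "forcefit": fit_images(images),     # energy/force profile along the path
    }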
17 changes: 11 additions & 6 deletions mlip_arena/tasks/optimize.py
@@ -15,7 +15,8 @@
from ase.optimize import * # type: ignore
from ase.optimize.optimize import Optimizer
from mlip_arena.models import MLIPEnum
from mlip_arena.tasks.utils import get_calculator
from mlip_arena.tasks.utils import get_calculator, logger, pformat


_valid_filters: dict[str, Filter] = {
"Filter": Filter,
@@ -94,16 +95,20 @@ def run(

if isinstance(filter, type) and issubclass(filter, Filter):
filter_instance = filter(atoms, **filter_kwargs)
print(f"Using filter: {filter_instance}")
logger.info(f"Using filter: {filter_instance}")
logger.info(pformat(filter_kwargs))

optimizer_instance = optimizer(atoms, **optimizer_kwargs)
print(f"Using optimizer: {optimizer_instance}")
optimizer_instance = optimizer(filter_instance, **optimizer_kwargs)
logger.info(f"Using optimizer: {optimizer_instance}")
logger.info(pformat(optimizer_kwargs))
logger.info(f"Criterion: {pformat(criterion)}")

optimizer_instance.run(**criterion)

elif filter is None:
optimizer_instance = optimizer(atoms, **optimizer_kwargs)
print(f"Using optimizer: {optimizer_instance}")
logger.info(f"Using optimizer: {optimizer_instance}")
logger.info(pformat(optimizer_kwargs))
logger.info(f"Criterion: {pformat(criterion)}")
optimizer_instance.run(**criterion)

return {
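The key fix above is that, when a filter is supplied, the optimizer is now constructed around filter_instance rather than the bare atoms, so cell degrees of freedom are actually relaxed. A minimal sketch of that pattern in plain ASE (assuming a recent ASE release that provides FrechetCellFilter; EMT stands in for an MLIP calculator):

from ase.build import bulk
from ase.calculators.emt import EMT
from ase.filters import FrechetCellFilter
from ase.optimize import BFGS

atoms = bulk("Cu", "fcc", a=3.7)         # deliberately strained lattice constant
atoms.calc = EMT()

relaxed = FrechetCellFilter(atoms)       # exposes positions and cell to the optimizer
BFGS(relaxed).run(fmax=0.02, steps=100)  # optimizer acts on the filter, not the Atoms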
20 changes: 11 additions & 9 deletions mlip_arena/tasks/utils.py
@@ -5,7 +5,7 @@
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator

from ase import units
from ase.calculators.calculator import Calculator
from ase.calculators.calculator import Calculator, BaseCalculator
from ase.calculators.mixing import SumCalculator
from mlip_arena.models import MLIPEnum
from mlip_arena.models.utils import get_freer_device
@@ -21,7 +21,7 @@


def get_calculator(
calculator_name: str | MLIPEnum | Calculator,
calculator_name: str | MLIPEnum | Calculator | SumCalculator,
calculator_kwargs: dict | None,
dispersion: bool = False,
dispersion_kwargs: dict | None = None,
@@ -30,22 +30,24 @@ def get_calculator(
"""Get a calculator with optional dispersion correction."""
device = device or str(get_freer_device())

logger.info("Using device: %s", device)
logger.info(f"Using device: {device}")

calculator_kwargs = calculator_kwargs or {}

if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
assert issubclass(calculator_name.value, Calculator)
calc = calculator_name.value(**calculator_kwargs)
elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
elif isinstance(calculator_name, Calculator):
logger.warning("Using custom calculator: {calculator_name}")
elif isinstance(calculator_name, type) and issubclass(calculator_name, BaseCalculator):
logger.warning(f"Using custom calculator class: {calculator_name}")
calc = calculator_name(**calculator_kwargs)
elif isinstance(calculator_name, Calculator | SumCalculator):
logger.warning(f"Using custom calculator object (kwargs are ignored): {calculator_name}")
calc = calculator_name
else:
raise ValueError(f"Invalid calculator: {calculator_name}")

logger.info("Using calculator: %s", calc)
logger.info(f"Using calculator: {calc}")
if calculator_kwargs:
logger.info(pformat(calculator_kwargs))

@@ -61,9 +63,9 @@
)
calc = SumCalculator([calc, disp_calc])

logger.info("Using dispersion: %s", disp_calc)
logger.info(f"Using dispersion: {disp_calc}")
if dispersion_kwargs:
logger.info(pformat(dispersion_kwargs))

assert isinstance(calc, Calculator) or isinstance(calc, SumCalculator)
assert isinstance(calc, Calculator | SumCalculator)
return calc
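get_calculator now accepts calculator classes, calculator instances, and SumCalculator objects, and composes the optional D3 correction through a SumCalculator. A sketch of that composition using only public ASE and torch-dftd APIs (EMT stands in for an MLIP; the damping and xc values are illustrative assumptions, not defaults set by this commit):

from ase.build import molecule
from ase.calculators.emt import EMT
from ase.calculators.mixing import SumCalculator
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator

base = EMT()  # stand-in for an MLIP calculator
d3 = TorchDFTD3Calculator(device="cpu", damping="bj", xc="pbe")

atoms = molecule("H2O")
atoms.calc = SumCalculator([base, d3])  # energies and forces from both terms are summed
print(atoms.get_potential_energy())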
