From fc4512147eba41899c42be1e126c747cfe49cd3d Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 16:28:02 +0100 Subject: [PATCH] move import in local scope where possible --- chiron/minimze.py | 16 ++++++++++------ chiron/neighbors.py | 4 ++-- chiron/potential.py | 17 ++++++++++------- chiron/states.py | 3 ++- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/chiron/minimze.py b/chiron/minimze.py index 0547e81..ea8fccf 100644 --- a/chiron/minimze.py +++ b/chiron/minimze.py @@ -1,16 +1,16 @@ -import jax -import jax.numpy as jnp -from jaxopt import GradientDescent -from loguru import logger as log +from typing import Callable +from jax import numpy as jnp -def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000): +def minimize_energy( + coordinates: jnp.array, potential_fn: Callable, nbr_list=None, maxiter: int = 1000 +): """ Minimize the potential energy of a system using JAXopt. Parameters ---------- - coordinates : jnp.ndarray + coordinates : jnp.array The initial coordinates of the system. potential_fn : callable The potential energy function of the system, which takes coordinates as input. @@ -24,6 +24,7 @@ def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000): jnp.ndarray The optimized coordinates. """ + from loguru import logger as log def objective_fn(x): if nbr_list is not None: @@ -33,6 +34,9 @@ def objective_fn(x): log.debug("Using NO neighbor list") return potential_fn(x) + from jaxopt import GradientDescent + import jax + optimizer = GradientDescent( fun=jax.value_and_grad(objective_fn), value_and_grad=True, maxiter=maxiter ) diff --git a/chiron/neighbors.py b/chiron/neighbors.py index 69a9502..9941e44 100644 --- a/chiron/neighbors.py +++ b/chiron/neighbors.py @@ -3,9 +3,8 @@ import jax import jax.numpy as jnp from functools import partial -from typing import Tuple, Optional, Union +from typing import Tuple, Union from .states import SamplerState -from loguru import logger as log from openmm import unit @@ -580,6 +579,7 @@ def build( ) self.neighbor_list = self.neighbor_list.reshape(-1, self.n_max_neighbors) + from loguru import logger as log while jnp.any(self.n_neighbors == self.n_max_neighbors).block_until_ready(): log.debug( diff --git a/chiron/potential.py b/chiron/potential.py index 14676dd..6d9b415 100644 --- a/chiron/potential.py +++ b/chiron/potential.py @@ -1,13 +1,13 @@ import jax import jax.numpy as jnp -from loguru import logger as log from openmm import unit from openmm.app import Topology -from typing import Optional class NeuralNetworkPotential: def __init__(self, model, **kwargs): + from loguru import logger as log + if model is None: log.warning("No model provided, using default model") else: @@ -166,6 +166,7 @@ def compute_energy(self, positions: jnp.array, nbr_list=None, debug_mode=False): """ # Compute the pair distances and displacement vectors + from loguru import logger as log if nbr_list is None: log.debug( @@ -319,11 +320,13 @@ def __init__( f"U0 must be a unit.Quantity with units of energy, U0.unit = {U0.unit}" ) - log.info("Initializing HarmonicOscillatorPotential") - log.info(f"k = {k}") - log.info(f"x0 = {x0}") - log.info(f"U0 = {U0}") - log.info("Energy is calculate: U(x) = (K/2) * ( (x-x0)^2 + y^2 + z^2 ) + U0") + from loguru import logger as log + + log.debug("Initializing HarmonicOscillatorPotential") + log.debug(f"k = {k}") + log.debug(f"x0 = {x0}") + log.debug(f"U0 = {U0}") + log.debug("Energy is calculate: U(x) = (K/2) * ( (x-x0)^2 + y^2 + z^2 ) + U0") self.k = jnp.array( k.value_in_unit_system(unit.md_unit_system) ) # spring constant diff --git a/chiron/states.py b/chiron/states.py index 8ba5e52..99ee8f9 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -1,7 +1,6 @@ from openmm import unit from typing import List, Optional, Union from jax import numpy as jnp -from loguru import logger as log from .potential import NeuralNetworkPotential @@ -204,6 +203,7 @@ def check_variables(self) -> None: def _check_completness(self): # check which variables are set set_variables = self.check_variables() + from loguru import logger as log if len(set_variables) == 0: log.info("No variables are set.") @@ -288,6 +288,7 @@ def calculate_reduced_potential_at_states( """ import numpy as np + from loguru import logger as log reduced_potentials = np.zeros(len(thermodynamic_states)) for state_idx, state in enumerate(thermodynamic_states):