Skip to content

Commit

Permalink
move import in local scope where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Jan 9, 2024
1 parent 50d38df commit fc45121
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 16 deletions.
16 changes: 10 additions & 6 deletions chiron/minimze.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import jax
import jax.numpy as jnp
from jaxopt import GradientDescent
from loguru import logger as log
from typing import Callable
from jax import numpy as jnp


def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000):
def minimize_energy(
coordinates: jnp.array, potential_fn: Callable, nbr_list=None, maxiter: int = 1000
):
"""
Minimize the potential energy of a system using JAXopt.
Parameters
----------
coordinates : jnp.ndarray
coordinates : jnp.array
The initial coordinates of the system.
potential_fn : callable
The potential energy function of the system, which takes coordinates as input.
Expand All @@ -24,6 +24,7 @@ def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000):
jnp.ndarray
The optimized coordinates.
"""
from loguru import logger as log

def objective_fn(x):
if nbr_list is not None:
Expand All @@ -33,6 +34,9 @@ def objective_fn(x):
log.debug("Using NO neighbor list")
return potential_fn(x)

from jaxopt import GradientDescent
import jax

optimizer = GradientDescent(
fun=jax.value_and_grad(objective_fn), value_and_grad=True, maxiter=maxiter
)
Expand Down
4 changes: 2 additions & 2 deletions chiron/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import jax
import jax.numpy as jnp
from functools import partial
from typing import Tuple, Optional, Union
from typing import Tuple, Union
from .states import SamplerState
from loguru import logger as log
from openmm import unit


Expand Down Expand Up @@ -580,6 +579,7 @@ def build(
)

self.neighbor_list = self.neighbor_list.reshape(-1, self.n_max_neighbors)
from loguru import logger as log

while jnp.any(self.n_neighbors == self.n_max_neighbors).block_until_ready():
log.debug(
Expand Down
17 changes: 10 additions & 7 deletions chiron/potential.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import jax
import jax.numpy as jnp
from loguru import logger as log
from openmm import unit
from openmm.app import Topology
from typing import Optional


class NeuralNetworkPotential:
def __init__(self, model, **kwargs):
from loguru import logger as log

if model is None:
log.warning("No model provided, using default model")
else:
Expand Down Expand Up @@ -166,6 +166,7 @@ def compute_energy(self, positions: jnp.array, nbr_list=None, debug_mode=False):
"""
# Compute the pair distances and displacement vectors
from loguru import logger as log

if nbr_list is None:
log.debug(
Expand Down Expand Up @@ -319,11 +320,13 @@ def __init__(
f"U0 must be a unit.Quantity with units of energy, U0.unit = {U0.unit}"
)

log.info("Initializing HarmonicOscillatorPotential")
log.info(f"k = {k}")
log.info(f"x0 = {x0}")
log.info(f"U0 = {U0}")
log.info("Energy is calculate: U(x) = (K/2) * ( (x-x0)^2 + y^2 + z^2 ) + U0")
from loguru import logger as log

log.debug("Initializing HarmonicOscillatorPotential")
log.debug(f"k = {k}")
log.debug(f"x0 = {x0}")
log.debug(f"U0 = {U0}")
log.debug("Energy is calculate: U(x) = (K/2) * ( (x-x0)^2 + y^2 + z^2 ) + U0")
self.k = jnp.array(
k.value_in_unit_system(unit.md_unit_system)
) # spring constant
Expand Down
3 changes: 2 additions & 1 deletion chiron/states.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from openmm import unit
from typing import List, Optional, Union
from jax import numpy as jnp
from loguru import logger as log
from .potential import NeuralNetworkPotential


Expand Down Expand Up @@ -204,6 +203,7 @@ def check_variables(self) -> None:
def _check_completness(self):
# check which variables are set
set_variables = self.check_variables()
from loguru import logger as log

if len(set_variables) == 0:
log.info("No variables are set.")
Expand Down Expand Up @@ -288,6 +288,7 @@ def calculate_reduced_potential_at_states(
"""
import numpy as np
from loguru import logger as log

reduced_potentials = np.zeros(len(thermodynamic_states))
for state_idx, state in enumerate(thermodynamic_states):
Expand Down

0 comments on commit fc45121

Please sign in to comment.