Skip to content

Commit

Permalink
Merge pull request #255 from PlasmaControl/derivs
Browse files Browse the repository at this point in the history
Allow alternate computation of derivatives
  • Loading branch information
f0uriest authored Sep 6, 2022
2 parents 6398d39 + 4093dec commit 4f2c9b4
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 42 deletions.
141 changes: 101 additions & 40 deletions desc/objectives/objective_funs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import scipy.linalg
from abc import ABC, abstractmethod
from inspect import getfullargspec

Expand All @@ -22,20 +23,29 @@ class ObjectiveFunction(IOAble):
Equilibrium that will be optimized to satisfy the objectives.
use_jit : bool, optional
Whether to just-in-time compile the objectives and derivatives.
deriv_mode : {"batched", "blocked"}
Method for computing derivatives. "batched" is generally faster, "blocked" may
use less memory. Note that the "blocked" Hessian will only be block diagonal.
verbose : int, optional
Level of output.
"""

_io_attrs_ = ["_objectives"]

def __init__(self, objectives, eq=None, use_jit=True, verbose=1):
def __init__(
self, objectives, eq=None, use_jit=True, deriv_mode="batched", verbose=1
):

if not isinstance(objectives, tuple):
objectives = (objectives,)

assert use_jit in {True, False}
assert deriv_mode in {"batched", "blocked"}

self._objectives = objectives
self._use_jit = use_jit
self._deriv_mode = deriv_mode
self._built = False
self._compiled = False

Expand Down Expand Up @@ -64,17 +74,62 @@ def _set_derivatives(self, use_jit=True):
Whether to just-in-time compile the objective and derivatives.
"""
self._grad = Derivative(self.compute_scalar, mode="grad", use_jit=use_jit)
self._hess = Derivative(
self.compute_scalar,
mode="hess",
use_jit=use_jit,
)
self._jac = Derivative(
self.compute,
mode="fwd",
use_jit=use_jit,
)

self._derivatives = {"jac": {}, "grad": {}, "hess": {}}
for arg in self.args:
self._derivatives["jac"][arg] = lambda x, arg=arg: jnp.vstack(
[
obj.derivatives["jac"][arg](
*self._kwargs_to_args(self.unpack_state(x), obj.args)
)
for obj in self.objectives
]
)
self._derivatives["grad"][arg] = lambda x, arg=arg: jnp.sum(
jnp.array(
[
obj.derivatives["grad"][arg](
*self._kwargs_to_args(self.unpack_state(x), obj.args)
)
for obj in self.objectives
]
),
axis=0,
)
self._derivatives["hess"][arg] = lambda x, arg=arg: jnp.sum(
jnp.array(
[
obj.derivatives["hess"][arg](
*self._kwargs_to_args(self.unpack_state(x), obj.args)
)
for obj in self.objectives
]
),
axis=0,
)

if self._deriv_mode == "blocked":
self._grad = lambda x: jnp.concatenate(
[jnp.atleast_1d(self._derivatives["grad"][arg](x)) for arg in self.args]
)
self._jac = lambda x: jnp.hstack(
[self._derivatives["jac"][arg](x) for arg in self.args]
)
self._hess = lambda x: scipy.linalg.block_diag(
*[self._derivatives["hess"][arg](x) for arg in self.args]
)
if self._deriv_mode == "batched":
self._grad = Derivative(self.compute_scalar, mode="grad", use_jit=use_jit)
self._hess = Derivative(
self.compute_scalar,
mode="hess",
use_jit=use_jit,
)
self._jac = Derivative(
self.compute,
mode="fwd",
use_jit=use_jit,
)

if use_jit:
self.compute = jit(self.compute)
Expand Down Expand Up @@ -137,7 +192,12 @@ def compute(self, x):
"""
kwargs = self.unpack_state(x)
f = jnp.concatenate([obj.compute(**kwargs) for obj in self.objectives])
f = jnp.concatenate(
[
obj.compute(*self._kwargs_to_args(kwargs, obj.args))
for obj in self.objectives
]
)
return f

def compute_scalar(self, x):
Expand Down Expand Up @@ -199,9 +259,13 @@ def unpack_state(self, x):

kwargs = {}
for arg in self.args:
kwargs[arg] = x[self.x_idx[arg]]
kwargs[arg] = jnp.atleast_1d(x[self.x_idx[arg]])
return kwargs

def _kwargs_to_args(self, kwargs, args):
    """Select values from ``kwargs`` in the order given by ``args``.

    Parameters
    ----------
    kwargs : dict
        Mapping from argument name to value.
    args : iterable of str
        Argument names to extract, in the desired positional order.

    Returns
    -------
    tuple_args : tuple
        The values ``kwargs[arg]`` for each ``arg`` in ``args``, suitable for
        star-unpacking into a positional call.

    Raises
    ------
    KeyError
        If a name in ``args`` is not a key of ``kwargs``.
    """
    # Build a real tuple rather than a generator: a generator can only be
    # consumed once and has no length, which is surprising for a value that
    # is named (and used as) a tuple of positional arguments.
    tuple_args = tuple(kwargs[arg] for arg in args)
    return tuple_args

def x(self, eq):
"""Return the full state vector from the Equilibrium eq."""
x = np.zeros((self.dim_x,))
Expand All @@ -211,18 +275,15 @@ def x(self, eq):

def grad(self, x):
"""Compute gradient vector of scalar form of the objective wrt x."""
# TODO: add block method
return self._grad.compute(x)
return jnp.atleast_1d(self._grad(x).squeeze())

def hess(self, x):
"""Compute Hessian matrix of scalar form of the objective wrt x."""
# TODO: add block method
return self._hess.compute(x)
return jnp.atleast_2d(self._hess(x).squeeze())

def jac(self, x):
"""Compute Jacobian matrx of vector form of the objective wrt x."""
J = self._jac.compute(x)
return jnp.atleast_2d(J)
return jnp.atleast_2d(self._jac(x).squeeze())

def jvp(self, v, x):
"""Compute Jacobian-vector product of the objective function.
Expand Down Expand Up @@ -414,39 +475,39 @@ def _set_dimensions(self, eq):

def _set_derivatives(self, use_jit=True):
"""Set up derivatives of the objective wrt each argument."""
self._derivatives = {}
self._scalar_derivatives = {}
self._derivatives = {"jac": {}, "grad": {}, "hess": {}}
self._args = [arg for arg in getfullargspec(self.compute)[0] if arg != "self"]

# only used for linear objectives so variable values are irrelevant
kwargs = dict( # FIXME: need to use dim_x
[(arg, np.zeros((self.dimensions[arg],))) for arg in self.dimensions.keys()]
)
args = [kwargs[arg] for arg in self.args]

# constant derivatives are pre-computed, otherwise set up Derivative instance
for arg in arg_order:
if arg in self.args: # derivative wrt arg
self._derivatives[arg] = Derivative(
self._derivatives["jac"][arg] = Derivative(
self.compute,
argnum=self.args.index(arg),
mode="fwd",
use_jit=use_jit,
)
self._scalar_derivatives[arg] = Derivative(
self._derivatives["grad"][arg] = Derivative(
self.compute_scalar,
argnum=self.args.index(arg),
mode="fwd",
mode="grad",
use_jit=use_jit,
)
self._derivatives["hess"][arg] = Derivative(
self.compute_scalar,
argnum=self.args.index(arg),
mode="hess",
use_jit=use_jit,
)
if self.linear: # linear objectives have constant derivatives
self._derivatives[arg] = self._derivatives[arg].compute(*args)
self._scalar_derivatives[arg] = self._scalar_derivatives[
arg
].compute(*args)
else: # these derivatives are always zero
self._derivatives[arg] = np.zeros((self.dim_f, self.dimensions[arg]))
self._scalar_derivatives[arg] = np.zeros((1, self.dimensions[arg]))
self._derivatives["jac"][arg] = lambda *args, **kwargs: jnp.zeros(
(self.dim_f, self.dimensions[arg])
)
self._derivatives["grad"][arg] = lambda *args, **kwargs: jnp.zeros(
(1, self.dimensions[arg])
)
self._derivatives["hess"][arg] = lambda *args, **kwargs: jnp.zeros(
(self.dimensions[arg], self.dimensions[arg])
)

if use_jit:
self.compute = jit(self.compute)
Expand Down Expand Up @@ -491,7 +552,7 @@ def compute_scalar(self, *args, **kwargs):
f = self.compute(*args, **kwargs)
else:
f = jnp.sum(self.compute(*args, **kwargs) ** 2) / 2
return f
return f.squeeze()

def print_value(self, *args, **kwargs):
"""Print the value of the objective."""
Expand Down
2 changes: 1 addition & 1 deletion desc/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def factorize_linear_constraints(constraints, extra_args=[]):
xp = put(xp, x_idx[obj.target_arg], obj.target)
else:
unfixed_args.append(arg)
A_ = obj.derivatives[arg]
A_ = obj.derivatives["jac"][arg](jnp.zeros(obj.dimensions[arg]))
b_ = obj.target
if A_.shape[0]:
Ainv_, Z_ = svd_inv_null(A_)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np
from desc.equilibrium import Equilibrium
from desc.objectives import (
ObjectiveFunction,
ForceBalance,
GenericObjective,
Energy,
Volume,
Expand Down Expand Up @@ -30,6 +32,8 @@ def test_volume(self):
obj = Volume(target=10 * np.pi ** 2, weight=1 / np.pi ** 2, eq=eq)
V = obj.compute(eq.R_lmn, eq.Z_lmn)
np.testing.assert_allclose(V, 10)
V_compute_scalar = obj.compute_scalar(eq.R_lmn, eq.Z_lmn)
np.testing.assert_allclose(V_compute_scalar, 10)

def test_aspect_ratio(self):
eq = Equilibrium()
Expand Down Expand Up @@ -82,3 +86,22 @@ def test_magnetic_well(self):
)
np.testing.assert_equal(len(magnetic_well), obj.grid.num_rho)
np.testing.assert_allclose(magnetic_well, 0, atol=1e-15)


def test_derivative_modes():
    """Check that "batched" and "blocked" derivative modes agree.

    Builds the same MagneticWell objective twice, once per ``deriv_mode``,
    and compares the gradient, Jacobian and Hessian diagonal evaluated at the
    equilibrium's state vector.
    """
    eq = Equilibrium(M=2, N=1, L=2)
    batched = ObjectiveFunction(MagneticWell(), deriv_mode="batched", use_jit=False)
    blocked = ObjectiveFunction(MagneticWell(), deriv_mode="blocked", use_jit=False)

    batched.build(eq)
    blocked.build(eq)
    state = batched.x(eq)

    np.testing.assert_allclose(
        batched.grad(state), blocked.grad(state), atol=1e-10
    )
    np.testing.assert_allclose(
        batched.jac(state), blocked.jac(state), atol=1e-10
    )
    # Only the diagonals of the two Hessians are compared here.
    np.testing.assert_allclose(
        np.diag(batched.hess(state)), np.diag(blocked.hess(state)), atol=1e-10
    )
2 changes: 1 addition & 1 deletion tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def test_plot_gradpsi(SOLOVEV):
return fig


@pytest.mark.mpl_image_compare(tolerance=50)
@pytest.mark.mpl_image_compare(tolerance=55)
def test_plot_normF_2d(SOLOVEV):
eq = EquilibriaFamily.load(load_from=str(SOLOVEV["desc_h5_path"]))[-1]
fig, ax = plot_2d(eq, "|F|", norm_F=True)
Expand Down

0 comments on commit 4f2c9b4

Please sign in to comment.