Skip to content

Commit

Permalink
extend ArraysInterface so that simplish_leastsq in simplerlm.py never directly accesses elements of JTJ.
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Nov 19, 2024
1 parent 6b0890c commit e72f449
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 58 deletions.
2 changes: 1 addition & 1 deletion pygsti/algorithms/gaugeopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def _call_jacobian_fn(gauge_group_el_vec):
assert(_call_jacobian_fn is not None), "Cannot use 'ls' method unless jacobian is available"
ralloc = _baseobjs.ResourceAllocation(comm) # FUTURE: plumb up a resource alloc object?
test_f = _call_objective_fn(x0)
solnX, converged, msg, _, _, _, _, _ = _opt.simplish_leastsq(
solnX, converged, msg, _, _, _, _ = _opt.simplish_leastsq(
_call_objective_fn, _call_jacobian_fn, x0, f_norm2_tol=tol,
jac_norm_tol=tol, rel_ftol=tol, rel_xtol=tol,
max_iter=maxiter, resource_alloc=ralloc,
Expand Down
28 changes: 28 additions & 0 deletions pygsti/optimize/arraysinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,19 @@ def jtj_diag_indices(self, jtj):
"""
return _np.diag_indices_from(jtj)

def jtj_update_regularization(self, jtj, prd, mu):
    """
    Overwrite the diagonal of `jtj` with ``prd + mu`` (Levenberg-Marquardt damping).

    `prd` is the pre-regularization diagonal data previously captured by
    :meth:`jtj_pre_regularization_data`; `mu` is the damping coefficient.
    Mutates `jtj` in place and returns nothing.
    """
    diag_indices = self.jtj_diag_indices(jtj)
    jtj[diag_indices] = prd + mu

def jtj_pre_regularization_data(self, jtj):
    """
    Snapshot the (undamped) diagonal of `jtj`.

    Returns an independent copy of the diagonal entries so they can later be
    re-applied with a damping term via :meth:`jtj_update_regularization`.
    """
    diag_indices = self.jtj_diag_indices(jtj)
    return jtj[diag_indices].copy()


def jtj_max_diagonal_element(self, jtj):
    """
    Return the largest diagonal entry of `jtj`, reduced through ``self.max_x``.
    """
    return self.max_x(jtj[self.jtj_diag_indices(jtj)])


class DistributedArraysInterface(ArraysInterface):
"""
Expand Down Expand Up @@ -626,6 +639,9 @@ def allocate_jac(self):
"""
Allocate an array for holding a Jacobian matrix (type `'ep'`).
Note: this function is only called when the Jacobian needs to be
approximated with finite differences.
Returns
-------
numpy.ndarray or LocalNumpyArray
Expand Down Expand Up @@ -1266,3 +1282,15 @@ def jtj_diag_indices(self, jtj):
col_indices = _np.arange(global_param_indices.start, global_param_indices.stop)
assert(len(row_indices) == len(col_indices)) # checks that global_param_indices is good
return row_indices, col_indices # ~ _np.diag_indices_from(jtj)

def jtj_update_regularization(self, jtj, prd, mu):
    """
    Set the diagonal entries of `jtj` (as selected by ``jtj_diag_indices``)
    to ``prd + mu``, i.e. apply Levenberg-Marquardt damping in place.
    """
    jtj[self.jtj_diag_indices(jtj)] = prd + mu

def jtj_pre_regularization_data(self, jtj):
    """
    Return an independent copy of the diagonal entries of `jtj` selected by
    ``jtj_diag_indices``, taken before any damping term is added.
    """
    idx = self.jtj_diag_indices(jtj)
    return jtj[idx].copy()

def jtj_max_diagonal_element(self, jtj):
    """
    Return the maximum diagonal entry of `jtj`; the reduction is delegated
    to ``self.max_x`` (which, for this interface, handles any cross-processor
    reduction — see the ArraysInterface contract).
    """
    diag_entries = jtj[self.jtj_diag_indices(jtj)]
    return self.max_x(diag_entries)
4 changes: 2 additions & 2 deletions pygsti/optimize/customsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as _np
import scipy as _scipy

from pygsti.optimize.arraysinterface import UndistributedArraysInterface as _UndistributedArraysInterface
from pygsti.optimize.arraysinterface import DistributedArraysInterface as _DistributedArraysInterface
from pygsti.tools import sharedmemtools as _smt
from pygsti.tools import slicetools as _slct

Expand Down Expand Up @@ -90,7 +90,7 @@ def custom_solve(a, b, x, ari, resource_alloc, proc_threshold=100):
host_comm = resource_alloc.host_comm
ok_buf = _np.empty(1, _np.int64)

if comm is None or isinstance(ari, _UndistributedArraysInterface):
if comm is None or (not isinstance(ari, _DistributedArraysInterface)):
x[:] = _scipy.linalg.solve(a, b, assume_a='pos')
return

Expand Down
86 changes: 31 additions & 55 deletions pygsti/optimize/simplerlm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Custom implementation of the Levenberg-Marquardt Algorithm
Custom implementation of the Levenberg-Marquardt Algorithm (but simpler than customlm.py)
"""
#***************************************************************************************************
# Copyright 2015, 2019 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
Expand All @@ -22,6 +22,7 @@
from pygsti.baseobjs.verbosityprinter import VerbosityPrinter as _VerbosityPrinter
from pygsti.baseobjs.resourceallocation import ResourceAllocation as _ResourceAllocation
from pygsti.baseobjs.nicelyserializable import NicelySerializable as _NicelySerializable
from pygsti.objectivefns.objectivefns import Chi2Function, TimeIndependentMDCObjectiveFunction
from typing import Callable

#Make sure SIGINT will generate a KeyboardInterrupt (even if we're launched in the background)
Expand Down Expand Up @@ -238,7 +239,7 @@ def _from_nice_serialization(cls, state):
serial_solve_proc_threshold=state['serial_solve_number_of_processors_threshold'],
lsvec_mode=state.get('lsvec_mode', 'normal'))

def run(self, objective, profiler, printer):
def run(self, objective: TimeIndependentMDCObjectiveFunction, profiler, printer):

"""
Perform the optimization.
Expand Down Expand Up @@ -282,7 +283,7 @@ def run(self, objective, profiler, printer):
else:
ari = _ari.UndistributedArraysInterface(nEls, nP)

opt_x, converged, msg, mu, nu, norm_f, f, opt_jtj = simplish_leastsq(
opt_x, converged, msg, mu, nu, norm_f, f = simplish_leastsq(
objective_func, jacobian, x0,
max_iter=self.maxiter,
num_fd_iters=self.fditer,
Expand Down Expand Up @@ -324,9 +325,8 @@ def run(self, objective, profiler, printer):
unpenalized_f = f[0:-objective.ex] if (objective.ex > 0) else f
unpenalized_normf = sum(unpenalized_f**2) # objective function without penalty factors
chi2k_qty = objective.chi2k_distributed_qty(norm_f)

return OptimizerResult(objective, opt_x, norm_f, opt_jtj, unpenalized_normf, chi2k_qty,
{'msg': msg, 'mu': mu, 'nu': nu, 'fvec': f})
optimizer_specific_qtys = {'msg': msg, 'mu': mu, 'nu': nu, 'fvec': f}
return OptimizerResult(objective, opt_x, norm_f, None, unpenalized_normf, chi2k_qty, optimizer_specific_qtys)



Expand Down Expand Up @@ -366,11 +366,6 @@ def jac_guarded(k: int, num_fd_iters: int, obj_fn: Callable, jac_fn: Callable, f
fdJac_work[:, i - pslice.start] = fd
#if comm is not None: comm.barrier() # overkill for shared memory leader host barrier
Jac = fdJac_work
#DEBUG: compare with analytic jacobian (need to uncomment num_fd_iters DEBUG line above too)
#Jac_analytic = jac_fn(x)
#if _np.linalg.norm(Jac_analytic-Jac) > 1e-6:
# print("JACDIFF = ",_np.linalg.norm(Jac_analytic-Jac)," per el=",
# _np.linalg.norm(Jac_analytic-Jac)/Jac.size," sz=",Jac.size)
return Jac


Expand Down Expand Up @@ -506,7 +501,7 @@ def simplish_leastsq(
best_x = ari.allocate_jtf()
dx = ari.allocate_jtf()
new_x = ari.allocate_jtf()
jtj_buf = ari.allocate_jtj_shared_mem_buf()
optional_jtj_buff = ari.allocate_jtj_shared_mem_buf()
fdJac = ari.allocate_jac() if num_fd_iters > 0 else None

global_x = x0.copy()
Expand Down Expand Up @@ -537,9 +532,8 @@ def simplish_leastsq(
# ^ We have to set some *some* values in case we exit at the start of the first
# iteration. mu will almost certainly be overwritten before being read.
min_norm_f = 1e100 # sentinel
best_x_state = (mu, nu, norm_f, f.copy(), None)
best_x_state = (mu, nu, norm_f, f.copy())
# ^ here and elsewhere, need f.copy() b/c f is objfn mem
rawJTJ_scratch = None

try:

Expand All @@ -558,41 +552,34 @@ def simplish_leastsq(
printer.log(("** Converged with out-of-bounds with check interval=%d, reverting to last know in-bounds point and setting interval=1 **") % oob_check_interval, 2)
oob_check_interval = 1
x[:] = best_x[:]
mu, nu, norm_f, f[:], _ = best_x_state
continue # can't make use of saved JTJ yet - recompute on nxt iter
mu, nu, norm_f, f[:] = best_x_state
continue

if profiler: profiler.memory_check("simplish_leastsq: begin outer iter")

Jac = jac_guarded(k, num_fd_iters, obj_fn, jac_fn, f, ari, global_x, fdJac)


if profiler: profiler.memory_check("simplish_leastsq: after jacobian:"
+ "shape=%s, GB=%.2f" % (str(Jac.shape),
Jac.nbytes / (1024.0**3)))
if profiler:
jac_gb = Jac.nbytes/(1024.0**3) if hasattr(Jac, 'nbytes') else _np.NaN
vals = ((f.size, global_x.size), jac_gb)
profiler.memory_check("simplish_leastsq: after jacobian: shape=%s, GB=%.2f" % vals)

Jnorm = _np.sqrt(ari.norm2_jac(Jac))
xnorm = _np.sqrt(ari.norm2_x(x))
printer.log("--- Outer Iter %d: norm_f = %g, mu=%g, |x|=%g, |J|=%g" % (k, norm_f, mu, xnorm, Jnorm))

#assert(_np.isfinite(Jac).all()), "Non-finite Jacobian!" # NaNs tracking
#assert(_np.isfinite(_np.linalg.norm(Jac))), "Finite Jacobian has inf norm!" # NaNs tracking

tm = _time.time()

# Riley note: fill_JTJ is the first place where we try to access J as a dense matrix.
ari.fill_jtj(Jac, JTJ, jtj_buf)
ari.fill_jtj(Jac, JTJ, optional_jtj_buff)
ari.fill_jtf(Jac, f, minus_JTf) # 'P'-type
minus_JTf *= -1

if profiler: profiler.add_time("simplish_leastsq: dotprods", tm)
#assert(not _np.isnan(JTJ).any()), "NaN in JTJ!" # NaNs tracking
#assert(not _np.isinf(JTJ).any()), "inf in JTJ! norm Jac = %g" % _np.linalg.norm(Jac) # NaNs tracking
#assert(_np.isfinite(JTJ).all()), "Non-finite JTJ!" # NaNs tracking
#assert(_np.isfinite(minus_JTf).all()), "Non-finite minus_JTf!" # NaNs tracking

idiag = ari.jtj_diag_indices(JTJ)
norm_JTf = ari.infnorm_x(minus_JTf)
norm_x = ari.norm2_x(x)
undamped_JTJ_diag = JTJ[idiag].copy() # 'P'-type
pre_reg_data = ari.jtj_pre_regularization_data(JTJ)

if norm_JTf < jac_norm_tol:
if oob_check_interval <= 1:
Expand All @@ -603,27 +590,21 @@ def simplish_leastsq(
printer.log(("** Converged with out-of-bounds with check interval=%d, reverting to last know in-bounds point and setting interval=1 **") % oob_check_interval, 2)
oob_check_interval = 1
x[:] = best_x[:]
mu, nu, norm_f, f[:], _ = best_x_state
continue # can't make use of saved JTJ yet - recompute on nxt iter
mu, nu, norm_f, f[:] = best_x_state
continue

if k == 0:
mu, nu = (tau * ari.max_x(undamped_JTJ_diag), 2) if init_munu == 'auto' else init_munu
rawJTJ_scratch = JTJ.copy() # allocates the memory for a copy of JTJ so only update mem elsewhere
best_x_state = (mu, nu, norm_f, f.copy(), rawJTJ_scratch) # update mu,nu,JTJ of initial best state
elif _np.allclose(x, best_x):
# for iter k > 0, update JTJ of best_x_state if best_x == x (i.e., if we've just evaluated
# a previously accepted step that was deemed the best we've seen so far.)
rawJTJ_scratch[:, :] = JTJ[:, :] # use pre-allocated memory
rawJTJ_scratch[idiag] = undamped_JTJ_diag # no damping; the "raw" JTJ
best_x_state = best_x_state[0:4] + (rawJTJ_scratch,) # update mu,nu,JTJ of initial "best state"
max_jtj_diag = ari.jtj_max_diagonal_element(JTJ)
mu, nu = (tau * max_jtj_diag, 2) if init_munu == 'auto' else init_munu
best_x_state = (mu, nu, norm_f, f.copy())

#determining increment using adaptive damping
while True: # inner loop

if profiler: profiler.memory_check("simplish_leastsq: begin inner iter")

# ok if assume fine-param-proc.size == 1 (otherwise need to sync setting local JTJ)
JTJ[idiag] = undamped_JTJ_diag + mu # augment normal equations
ari.jtj_update_regularization(JTJ, pre_reg_data, mu)

#assert(_np.isfinite(JTJ).all()), "Non-finite JTJ (inner)!" # NaNs tracking
#assert(_np.isfinite(minus_JTf).all()), "Non-finite minus_JTf (inner)!" # NaNs tracking
Expand Down Expand Up @@ -676,7 +657,7 @@ def simplish_leastsq(
printer.log(("** Converged with out-of-bounds with check interval=%d, reverting to last know in-bounds point and setting interval=1 **") % oob_check_interval, 2)
oob_check_interval = 1
x[:] = best_x[:]
mu, nu, norm_f, f[:], _ = best_x_state
mu, nu, norm_f, f[:] = best_x_state
break
elif (norm_x + rel_xtol) < norm_dx * (_MACH_PRECISION**2):
msg = "(near-)singular linear system"
Expand Down Expand Up @@ -715,7 +696,7 @@ def simplish_leastsq(
printer.log(("** Hit out-of-bounds with check interval=%d, reverting to last know in-bounds point and setting interval=1 **") % oob_check_interval, 2)
oob_check_interval = 1
x[:] = best_x[:]
mu, nu, norm_f, f[:], _ = best_x_state # can't make use of saved JTJ yet
mu, nu, norm_f, f[:] = best_x_state
break # restart next outer loop
else:
raise ValueError("Invalid `oob_action`: '%s'" % oob_action)
Expand Down Expand Up @@ -750,7 +731,7 @@ def simplish_leastsq(
printer.log(("** Converged with out-of-bounds with check interval=%d, reverting to last know in-bounds point and setting interval=1 **") % oob_check_interval, 2)
oob_check_interval = 1
x[:] = best_x[:]
mu, nu, norm_f, f[:], _ = best_x_state # can't make use of saved JTJ yet
mu, nu, norm_f, f[:] = best_x_state
break

if (dL <= 0 or dF <= 0):
Expand Down Expand Up @@ -785,7 +766,7 @@ def simplish_leastsq(
printer.log(("** Hit out-of-bounds with check interval=%d, reverting to last know in-bounds point and setting interval=1 **") % oob_check_interval, 2)
oob_check_interval = 1
x[:] = best_x[:]
mu, nu, norm_f, f[:], _ = best_x_state # can't use of saved JTJ yet
mu, nu, norm_f, f[:] = best_x_state
break # restart next outer loop
else:
raise ValueError("Invalid `oob_action`: '%s'" % oob_action)
Expand All @@ -805,10 +786,7 @@ def simplish_leastsq(
if new_x_is_known_inbounds and norm_f < min_norm_f:
min_norm_f = norm_f
best_x[:] = x[:]
best_x_state = (mu, nu, norm_f, f.copy(), None)
#Note: we use rawJTJ=None above because the current `JTJ` was evaluated
# at the *last* x-value -- we need to wait for the next outer loop
# to compute the JTJ for this best_x_state
best_x_state = (mu, nu, norm_f, f.copy())

#assert(_np.isfinite(x).all()), "Non-finite x!" # NaNs tracking
#assert(_np.isfinite(f).all()), "Non-finite f!" # NaNs tracking
Expand Down Expand Up @@ -840,7 +818,7 @@ def simplish_leastsq(
ari.deallocate_jtj(JTJ)
ari.deallocate_jtf(minus_JTf)
ari.deallocate_jtf(x)
ari.deallocate_jtj_shared_mem_buf(jtj_buf)
ari.deallocate_jtj_shared_mem_buf(optional_jtj_buff)

if x_limits is not None:
ari.deallocate_jtf(x_lower_limits)
Expand All @@ -855,11 +833,9 @@ def simplish_leastsq(
ari.allgather_x(best_x, global_x)
ari.deallocate_jtf(best_x)

#JTJ[idiag] = undampled_JTJ_diag #restore diagonal
mu, nu, norm_f, f[:], rawJTJ = best_x_state
mu, nu, norm_f, f[:] = best_x_state

global_f = _np.empty(ari.global_num_elements(), 'd')
ari.allgather_f(f, global_f)

return global_x, converged, msg, mu, nu, norm_f, global_f, rawJTJ

return global_x, converged, msg, mu, nu, norm_f, global_f

0 comments on commit e72f449

Please sign in to comment.