Skip to content

Commit a2d0820

Browse files
committed
Added access to observable targets
1 parent 6438ed5 commit a2d0820

File tree

2 files changed

+52
-11
lines changed

2 files changed

+52
-11
lines changed

docs/p/notebooks/response_matrices.ipynb

+4-4
Original file line numberDiff line numberDiff line change
@@ -550,8 +550,8 @@
550550
"name": "stderr",
551551
"output_type": "stream",
552552
"text": [
553-
"/Users/laurent/dev/libraries/at/pyat/at/latticetools/response_matrix.py:584: AtWarning: No new excluded value\n",
554-
" warnings.warn(AtWarning(\"No new excluded value\"), stacklevel=1)\n"
553+
"/var/folders/9d/tcctx2j125xd3nzkr5bp1zq4000103/T/ipykernel_52637/1399543609.py:1: AtWarning: No new excluded value\n",
554+
" resp_dx.exclude_obs(obsid=0, refpts=\"BPM_07\")\n"
555555
]
556556
}
557557
],
@@ -591,8 +591,8 @@
591591
"name": "stderr",
592592
"output_type": "stream",
593593
"text": [
594-
"/Users/laurent/dev/libraries/at/pyat/at/latticetools/response_matrix.py:584: AtWarning: No new excluded value\n",
595-
" warnings.warn(AtWarning(\"No new excluded value\"), stacklevel=1)\n"
594+
"/var/folders/9d/tcctx2j125xd3nzkr5bp1zq4000103/T/ipykernel_52637/3319419291.py:1: AtWarning: No new excluded value\n",
595+
" resp_dx.exclude_obs(refpts=\"BPM_07\")\n"
596596
]
597597
}
598598
],

pyat/at/latticetools/response_matrix.py

+48-7
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@
143143
import concurrent.futures
144144
import abc
145145
import warnings
146-
from collections.abc import Sequence, Generator
146+
from collections.abc import Sequence, Generator, Callable
147147
from typing import Any, ClassVar
148148
from itertools import chain
149149
from functools import partial
@@ -451,29 +451,31 @@ def correct(
451451
ring: Lattice description. The response matrix observables
452452
will be evaluated for *ring* and the deviation from target will
453453
be corrected
454-
nvals: Desired number of singular values. If :py:obj:`None`,
455-
use all singular values
456454
apply: If :py:obj:`True`, apply the correction to *ring*
457455
niter: Number of iterations. For more than one iteration,
458456
*apply* must be :py:obj:`True`
457+
nvals: Desired number of singular values. If :py:obj:`None`,
458+
use all singular values. *nvals* may be a scalar or an iterable with
459+
*niter* values.
459460
460461
Returns:
461462
correction: Vector of correction values
462463
"""
463464
if niter > 1 and not apply:
464-
raise ValueError("needs: apply is True")
465+
raise ValueError("If niter > 1, 'apply' must be True")
465466
obs = self.observables
466467
if apply:
467468
self.variables.get(ring=ring, initial=True)
468469
sumcorr = np.array([0.0])
469-
for it in range(niter):
470+
for it, nv in zip(range(niter), np.broadcast_to(nvals, (niter,))):
471+
print(f'step {it+1}, nvals = {nv}')
470472
obs.evaluate(ring, **self.eval_args)
471473
err = obs.flat_deviations
472474
if np.any(np.isnan(err)):
473475
raise AtError(
474476
f"Step {it + 1}: Invalid observables, cannot compute correction"
475477
)
476-
corr = self.get_correction(obs.flat_deviations, nvals=nvals)
478+
corr = self.get_correction(obs.flat_deviations, nvals=nv)
477479
sumcorr = sumcorr + corr # non-broadcastable sumcorr
478480
if apply:
479481
self.variables.increment(corr, ring=ring)
@@ -547,6 +549,45 @@ def build_analytical(self) -> FloatArray:
547549
f"build_analytical not implemented for {self.__class__.__name__}"
548550
)
549551

552+
def _on_obs(self, fun: Callable, *args, obsid: int | str = 0):
553+
"""Apply a function to the selected observable"""
554+
if not isinstance(obsid, str):
555+
return fun(self.observables[obsid], *args)
556+
else:
557+
for obs in self.observables:
558+
if obs.name == obsid:
559+
return fun(obs, *args)
560+
else:
561+
raise ValueError(f"Observable {obsid} not found")
562+
563+
def get_target(self, *, obsid: int | str = 0) -> FloatArray:
564+
r"""Return the target of the specified observable
565+
566+
Args:
567+
obsid: :py:class:`.Observable` name or index in the observable list.
568+
569+
Returns:
570+
target: observable target
571+
"""
572+
def _get(obs):
573+
return obs.target
574+
575+
return self._on_obs(_get, obsid=obsid)
576+
577+
def set_target(self, target: npt.ArrayLike, *, obsid: int | str = 0) -> None:
578+
r"""Set the target of the specified observable
579+
580+
Args:
581+
target: observable target. Must be broadcastable to the shape of the
582+
observable value.
583+
obsid: :py:class:`.Observable` name or index in the observable list.
584+
"""
585+
586+
def _set(obs, targ):
587+
obs.target = targ
588+
589+
return self._on_obs(_set, target, obsid=obsid)
590+
550591
def exclude_obs(self, *, obsid: int | str = 0, refpts: Refpts = None) -> None:
551592
# noinspection PyUnresolvedReferences
552593
r"""Add an observable item to the set of excluded values
@@ -581,7 +622,7 @@ def exclude(ob, msk):
581622
else:
582623
msk[:] = False
583624
if np.all(msk == inimask):
584-
warnings.warn(AtWarning("No new excluded value"), stacklevel=1)
625+
warnings.warn(AtWarning("No new excluded value"), stacklevel=3)
585626
# Force a new computation
586627
self.singular_values = None
587628

0 commit comments

Comments
 (0)