Skip to content

Commit 8bd436d

Browse files
committed
Documentation
1 parent 2ca5935 commit 8bd436d

File tree

3 files changed

+43
-38
lines changed

3 files changed

+43
-38
lines changed

pyat/at/latticetools/observables.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,9 @@ def evaluate(self, *data, initial: bool = False):
312312
sent to the evaluation function
313313
initial: It :py:obj:`None`, store the result as the initial
314314
value
315+
316+
Returns:
317+
value: The value of the observable.
315318
"""
316319
for d in data:
317320
if isinstance(d, Exception):
@@ -847,7 +850,9 @@ def __init__(
847850
848851
Example:
849852
850-
>>> obs = LatticeObservable(at.Sextupole, "KickAngle", index=0, statfun=np.sum)
853+
>>> obs = LatticeObservable(
854+
... at.Sextupole, "KickAngle", index=0, statfun=np.sum
855+
... )
851856
852857
Observe the sum of horizontal kicks in Sextupoles
853858
"""

pyat/at/latticetools/response_matrix.py

+33-34
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
... )
3131
3232
The variables are the horizontal displacement ``dx`` of all quadrupoles. The variable
33-
name is set to *dx_nnnn* where *nnnn* is the index of the quadruple in the lattice.
33+
name is set to *dx_nnnn* where *nnnn* is the index of the quadrupole in the lattice.
3434
The step is set to 0.0001 m.
3535
3636
Let's take the horizontal positions at all beam position monitors as observables:
@@ -228,8 +228,9 @@ class _SvdSolver(abc.ABC):
228228
_obsmask: npt.NDArray[bool]
229229
_varmask: npt.NDArray[bool]
230230
_response: FloatArray | None = None
231-
v: FloatArray | None = None
232-
uh: FloatArray | None = None
231+
_v: FloatArray | None = None
232+
_uh: FloatArray | None = None
233+
#: Singular values of the response matrix
233234
singular_values: FloatArray | None = None
234235

235236
def __init__(self, nobs: int, nvar: int):
@@ -240,15 +241,15 @@ def __init__(self, nobs: int, nvar: int):
240241
def reset_vars(self):
241242
"""Reset the variable exclusion mask: enable all variables"""
242243
self._varmask = np.ones(self.shape[1], dtype=bool)
243-
self.v = None
244-
self.uh = None
244+
self._v = None
245+
self._uh = None
245246
self.singular_values = None
246247

247248
def reset_obs(self):
248249
"""Reset the observable exclusion mask: enable all observables"""
249250
self._obsmask = np.ones(self.shape[0], dtype=bool)
250-
self.v = None
251-
self.uh = None
251+
self._v = None
252+
self._uh = None
252253
self.singular_values = None
253254

254255
@property
@@ -269,11 +270,11 @@ def solve(self) -> None:
269270
resp = self.weighted_response
270271
selected = np.ix_(self._obsmask, self._varmask)
271272
u, s, vh = np.linalg.svd(resp[selected], full_matrices=False)
272-
self.v = vh.T * (1 / s) * self.varweights[self._varmask].reshape(-1, 1)
273-
self.uh = u.T / self.obsweights[self._obsmask]
273+
self._v = vh.T * (1 / s) * self.varweights[self._varmask].reshape(-1, 1)
274+
self._uh = u.T / self.obsweights[self._obsmask]
274275
self.singular_values = s
275276

276-
def check_norm(self) -> tuple[np.ndarray, np.ndarray]:
277+
def check_norm(self) -> tuple[FloatArray, FloatArray]:
277278
"""Display the norm of the rows and columns of the weighted response matrix.
278279
279280
Adjusting the variables and observable weights to equalize the norms
@@ -309,7 +310,7 @@ def response(self, response: FloatArray) -> None:
309310
self._response = response
310311

311312
@property
312-
def weighted_response(self) -> np.ndarray:
313+
def weighted_response(self) -> FloatArray:
313314
"""Weighted response matrix."""
314315
return self.response * (self.varweights / self.obsweights.reshape(-1, 1))
315316

@@ -329,11 +330,11 @@ def correction_matrix(self, nvals: int | None = None) -> FloatArray:
329330
nvals = len(self.singular_values)
330331
cormat = np.zeros(self._shape[::-1])
331332
selected = np.ix_(self._varmask, self._obsmask)
332-
cormat[selected] = self.v[:, :nvals] @ self.uh[:nvals, :]
333+
cormat[selected] = self._v[:, :nvals] @ self._uh[:nvals, :]
333334
return cormat
334335

335336
def get_correction(
336-
self, observed: np.ndarray, nvals: int | None = None
337+
self, observed: FloatArray, nvals: int | None = None
337338
) -> FloatArray:
338339
"""Compute the correction of the given observation.
339340
@@ -384,7 +385,7 @@ class ResponseMatrix(_SvdSolver):
384385
ring: Lattice
385386
variables: VariableList #: List of matrix :py:class:`Variable <.VariableBase>`\ s
386387
observables: ObservableList #: List of matrix :py:class:`.Observable`\s
387-
eval_args: dict[str, Any] = {}
388+
_eval_args: dict[str, Any] = {}
388389

389390
def __init__(
390391
self,
@@ -433,7 +434,7 @@ def __str__(self):
433434
return f"{type(self).__name__}({no} observables, {nv} variables)"
434435

435436
@property
436-
def varweights(self) -> np.ndarray:
437+
def varweights(self):
437438
"""Variable weights."""
438439
return self.variables.deltas
439440

@@ -444,7 +445,7 @@ def obsweights(self) -> np.ndarray:
444445

445446
def correct(
446447
self, ring: Lattice, nvals: int = None, niter: int = 1, apply: bool = False
447-
) -> np.ndarray:
448+
) -> FloatArray:
448449
"""Compute and optionally apply the correction.
449450
450451
Args:
@@ -469,7 +470,7 @@ def correct(
469470
sumcorr = np.array([0.0])
470471
for it, nv in zip(range(niter), np.broadcast_to(nvals, (niter,))):
471472
print(f'step {it+1}, nvals = {nv}')
472-
obs.evaluate(ring, **self.eval_args)
473+
obs.evaluate(ring, **self._eval_args)
473474
err = obs.flat_deviations
474475
if np.any(np.isnan(err)):
475476
raise AtError(
@@ -512,7 +513,7 @@ def build_tracking(
512513
Returns:
513514
response: Response matrix
514515
"""
515-
self.eval_args = kwargs
516+
self._eval_args = kwargs
516517
self.observables.evaluate(self.ring)
517518
ring = self.ring.deepcopy()
518519

@@ -720,8 +721,8 @@ class OrbitResponseMatrix(ResponseMatrix):
720721
... )
721722
"""
722723

723-
bpmrefs: Uint32Refpts
724-
steerrefs: Uint32Refpts
724+
bpmrefs: Uint32Refpts #: location of position monitors
725+
steerrefs: Uint32Refpts #: location of steerers
725726

726727
def __init__(
727728
self,
@@ -731,7 +732,7 @@ def __init__(
731732
steerrefs: Refpts = _orbit_correctors,
732733
*,
733734
cavrefs: Refpts = None,
734-
bpmweight: float = 1.0,
735+
bpmweight: float | Sequence[float] = 1.0,
735736
bpmtarget: float | Sequence[float] = 0.0,
736737
steerdelta: float | Sequence[float] = 0.0001,
737738
cavdelta: float | None = None,
@@ -761,8 +762,6 @@ def __init__(
761762
is also the cavity weight. Default: automatically computed.
762763
steerdelta: Step on steerers for matrix computation [rad]. This is
763764
also the steerer weight. Must be broadcastable to the number of steerers.
764-
cavdelta: Step on RF frequency for matrix computation [Hz]. This
765-
is also the cavity weight
766765
steersum: If :py:obj:`True`, the sum of steerers is appended to the
767766
Observables.
768767
stsumweight: Weight on steerer summation. Default: automatically computed.
@@ -980,39 +979,39 @@ def tauwj(muj, muw):
980979
return resp
981980

982981
@property
983-
def bpmweight(self) -> np.ndarray:
982+
def bpmweight(self) -> FloatArray:
984983
"""Weight of position readings."""
985984
return self.observables[0].weight
986985

987986
@bpmweight.setter
988-
def bpmweight(self, value):
987+
def bpmweight(self, value: npt.ArrayLike):
989988
self.observables[0].weight = value
990989

991990
@property
992-
def stsumweight(self) -> np.ndarray:
991+
def stsumweight(self) -> FloatArray:
993992
"""Weight of steerer summation."""
994993
return self.observables[1].weight
995994

996995
@stsumweight.setter
997-
def stsumweight(self, value):
996+
def stsumweight(self, value: float):
998997
self.observables[1].weight = value
999998

1000999
@property
1001-
def steerdelta(self) -> np.ndarray:
1000+
def steerdelta(self) -> FloatArray:
10021001
"""Step and weight of steerers."""
10031002
return self.variables[: self.nbsteers].deltas
10041003

10051004
@steerdelta.setter
1006-
def steerdelta(self, value):
1005+
def steerdelta(self, value: npt.ArrayLike):
10071006
self.variables[: self.nbsteers].deltas = value
10081007

10091008
@property
1010-
def cavdelta(self) -> np.ndarray:
1009+
def cavdelta(self) -> FloatArray:
10111010
"""Step and weight of RF frequency deviation."""
10121011
return self.variables[self.nbsteers].delta
10131012

10141013
@cavdelta.setter
1015-
def cavdelta(self, value):
1014+
def cavdelta(self, value: float):
10161015
self.variables[self.nbsteers].delta = value
10171016

10181017

@@ -1111,7 +1110,7 @@ def build_analytical(self, **kwargs) -> FloatArray:
11111110
"""
11121111
ring = self.ring
11131112
pl = self.plane
1114-
twiss_in = self.eval_args.get("twiss_in", self._default_twiss_in)
1113+
twiss_in = self._eval_args.get("twiss_in", self._default_twiss_in)
11151114
_, _, elemdata = ring.linopt6(All, twiss_in=twiss_in, **kwargs)
11161115
dataj = elemdata[self.bpmrefs]
11171116
dataw = elemdata[self.steerrefs]
@@ -1178,12 +1177,12 @@ def exclude_vars(self, *varid: int | str, refpts: Refpts = None) -> None:
11781177
super().exclude_vars(*varid, *names)
11791178

11801179
@property
1181-
def bpmweight(self) -> np.ndarray:
1180+
def bpmweight(self) -> FloatArray:
11821181
"""Weight of position readings."""
11831182
return self.observables[0].weight
11841183

11851184
@bpmweight.setter
1186-
def bpmweight(self, value):
1185+
def bpmweight(self, value: npt.ArrayLike):
11871186
self.observables[0].weight = value
11881187

11891188
@property

pyat/at/plot/response_matrix.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ def plot_obs_analysis(
6868
if resp.singular_values is None:
6969
resp.solve()
7070
obs = resp.observables
71-
obs.evaluate(lattice, **resp.eval_args)
72-
corr = resp.uh @ obs.flat_deviations
71+
# noinspection PyProtectedMember
72+
obs.evaluate(lattice, **resp._eval_args)
73+
corr = resp._uh @ obs.flat_deviations
7374
if ax is None:
7475
fig, ax = plt.subplots()
7576
ax.bar(range(len(corr)), corr)
@@ -98,7 +99,7 @@ def plot_var_analysis(
9899
var = resp.variables
99100
if ax is None:
100101
fig, ax = plt.subplots()
101-
corr = (resp.v * resp.singular_values).T @ var.get(lattice)
102+
corr = (resp._v * resp.singular_values).T @ var.get(lattice)
102103
ax.bar(range(len(corr)), corr)
103104
if logscale:
104105
ax.set_yscale("log")

0 commit comments

Comments
 (0)