|
143 | 143 | import concurrent.futures
|
144 | 144 | import abc
|
145 | 145 | import warnings
|
146 |
| -from collections.abc import Sequence, Generator |
| 146 | +from collections.abc import Sequence, Generator, Callable |
147 | 147 | from typing import Any, ClassVar
|
148 | 148 | from itertools import chain
|
149 | 149 | from functools import partial
|
@@ -451,29 +451,31 @@ def correct(
|
451 | 451 | ring: Lattice description. The response matrix observables
|
452 | 452 | will be evaluated for *ring* and the deviation from target will
|
453 | 453 | be corrected
|
454 |
| - nvals: Desired number of singular values. If :py:obj:`None`, |
455 |
| - use all singular values |
456 | 454 | apply: If :py:obj:`True`, apply the correction to *ring*
|
457 | 455 | niter: Number of iterations. For more than one iteration,
|
458 | 456 | *apply* must be :py:obj:`True`
|
| 457 | + nvals: Desired number of singular values. If :py:obj:`None`, |
| 458 | + use all singular values. *nvals* may be a scalar or an iterable with |
| 459 | + *niter* values. |
459 | 460 |
|
460 | 461 | Returns:
|
461 | 462 | correction: Vector of correction values
|
462 | 463 | """
|
463 | 464 | if niter > 1 and not apply:
|
464 |
| - raise ValueError("needs: apply is True") |
| 465 | + raise ValueError("If niter > 1, 'apply' must be True") |
465 | 466 | obs = self.observables
|
466 | 467 | if apply:
|
467 | 468 | self.variables.get(ring=ring, initial=True)
|
468 | 469 | sumcorr = np.array([0.0])
|
469 |
| - for it in range(niter): |
| 470 | + for it, nv in zip(range(niter), np.broadcast_to(nvals, (niter,))): |
| 471 | + print(f'step {it+1}, nvals = {nv}') |
470 | 472 | obs.evaluate(ring, **self.eval_args)
|
471 | 473 | err = obs.flat_deviations
|
472 | 474 | if np.any(np.isnan(err)):
|
473 | 475 | raise AtError(
|
474 | 476 | f"Step {it + 1}: Invalid observables, cannot compute correction"
|
475 | 477 | )
|
476 |
| - corr = self.get_correction(obs.flat_deviations, nvals=nvals) |
| 478 | + corr = self.get_correction(obs.flat_deviations, nvals=nv) |
477 | 479 | sumcorr = sumcorr + corr # non-broadcastable sumcorr
|
478 | 480 | if apply:
|
479 | 481 | self.variables.increment(corr, ring=ring)
|
@@ -547,6 +549,45 @@ def build_analytical(self) -> FloatArray:
|
547 | 549 | f"build_analytical not implemented for {self.__class__.__name__}"
|
548 | 550 | )
|
549 | 551 |
|
| 552 | + def _on_obs(self, fun: Callable, *args, obsid: int | str = 0): |
| 553 | + """Apply a function to the selected observable""" |
| 554 | + if not isinstance(obsid, str): |
| 555 | + return fun(self.observables[obsid], *args) |
| 556 | + else: |
| 557 | + for obs in self.observables: |
| 558 | + if obs.name == obsid: |
| 559 | + return fun(obs, *args) |
| 560 | + else: |
| 561 | + raise ValueError(f"Observable {obsid} not found") |
| 562 | + |
| 563 | + def get_target(self, *, obsid: int | str = 0) -> FloatArray: |
| 564 | + r"""Return the target of the specified observable |
| 565 | +
|
| 566 | + Args: |
| 567 | + obsid: :py:class:`.Observable` name or index in the observable list. |
| 568 | +
|
| 569 | + Returns: |
| 570 | + target: observable target |
| 571 | + """ |
| 572 | + def _get(obs): |
| 573 | + return obs.target |
| 574 | + |
| 575 | + return self._on_obs(_get, obsid=obsid) |
| 576 | + |
| 577 | + def set_target(self, target: npt.ArrayLike, *, obsid: int | str = 0) -> None: |
| 578 | + r"""Set the target of the specified observable |
| 579 | +
|
| 580 | + Args: |
| 581 | + target: observable target. Must be broadcastable to the shape of the |
| 582 | + observable value. |
| 583 | + obsid: :py:class:`.Observable` name or index in the observable list. |
| 584 | + """ |
| 585 | + |
| 586 | + def _set(obs, targ): |
| 587 | + obs.target = targ |
| 588 | + |
| 589 | + return self._on_obs(_set, target, obsid=obsid) |
| 590 | + |
550 | 591 | def exclude_obs(self, *, obsid: int | str = 0, refpts: Refpts = None) -> None:
|
551 | 592 | # noinspection PyUnresolvedReferences
|
552 | 593 | r"""Add an observable item to the set of excluded values
|
@@ -581,7 +622,7 @@ def exclude(ob, msk):
|
581 | 622 | else:
|
582 | 623 | msk[:] = False
|
583 | 624 | if np.all(msk == inimask):
|
584 |
| - warnings.warn(AtWarning("No new excluded value"), stacklevel=1) |
| 625 | + warnings.warn(AtWarning("No new excluded value"), stacklevel=3) |
585 | 626 | # Force a new computation
|
586 | 627 | self.singular_values = None
|
587 | 628 |
|
|
0 commit comments