Skip to content

Commit 6c3034b

Browse files
committed
handling of errors
1 parent edceae6 commit 6c3034b

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

pyat/at/latticetools/observablelist.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"ObservableList",
2727
]
2828

29-
from collections.abc import Iterable
29+
from collections.abc import Iterable, Iterator
3030
from functools import reduce
3131
from typing import Callable
3232

@@ -44,6 +44,17 @@ def _flatten(vals, order="F"):
4444
return np.concatenate([np.reshape(v, -1, order=order) for v in vals])
4545

4646

47+
class _ObsResIter(Iterator):
48+
def __init__(self, obsiter):
49+
self.base = obsiter
50+
51+
def __next__(self):
52+
val = next(self.base)
53+
if isinstance(val, Exception):
54+
raise val
55+
return val
56+
57+
4758
class _ObsResults(tuple):
4859
def __getitem__(self, item):
4960
if isinstance(item, slice):
@@ -54,6 +65,9 @@ def __getitem__(self, item):
5465
raise AtError(f"Evaluation failed: {val.args[0]}") from val
5566
return val
5667

68+
def __iter__(self):
69+
return _ObsResIter(super().__iter__())
70+
5771

5872
class ObservableList(list):
5973
"""Handles a list of Observables to be evaluated together.
@@ -265,7 +279,7 @@ def obseval(ring, obs):
265279
"""Evaluate a single observable."""
266280

267281
def check_error(data, refpts):
268-
return data if isinstance(data, AtError) else data[refpts]
282+
return data if isinstance(data, Exception) else data[refpts]
269283

270284
obsneeds = obs.needs
271285
obsrefs = getattr(obs, "_boolrefs", None)
@@ -385,7 +399,7 @@ def ringeval(
385399
trajs, orbits, rgdata, eldata, emdata, mxdata, geodata = ringeval(
386400
ring, dp=dp, dct=dct, df=df
387401
)
388-
return [obseval(ring, ob) for ob in self]
402+
return _ObsResults(obseval(ring, ob) for ob in self)
389403

390404
def check(self) -> bool:
391405
"""Check the evaluation

pyat/at/latticetools/observables.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,12 @@ def evaluate(self, *data, initial: bool = False):
330330
"""
331331
for d in data:
332332
if isinstance(d, Exception):
333-
self._value = d
333+
errtype = type(d)
334+
err = errtype(f"Evaluation of {self.name} failed: {d.args[0]}")
335+
err.__cause__ = d
336+
self._value = err
334337
self._shape = None
335-
return d
338+
return err
336339

337340
val = np.asarray(self.fun(*data, *self.args, **self.kwargs))
338341
if initial:
@@ -357,7 +360,7 @@ def value(self):
357360
"""Value of the observable."""
358361
val = self._value
359362
if isinstance(val, Exception):
360-
raise AtError(f"Evaluation of {self.name} failed: {val.args[0]}") from val
363+
raise val
361364
return val
362365

363366
@property

0 commit comments

Comments
 (0)