26
26
"ObservableList" ,
27
27
]
28
28
29
- from collections .abc import Iterable
29
+ from collections .abc import Iterable , Iterator
30
30
from functools import reduce
31
31
from typing import Callable
32
32
@@ -44,6 +44,17 @@ def _flatten(vals, order="F"):
44
44
return np .concatenate ([np .reshape (v , - 1 , order = order ) for v in vals ])
45
45
46
46
47
+ class _ObsResIter (Iterator ):
48
+ def __init__ (self , obsiter ):
49
+ self .base = obsiter
50
+
51
+ def __next__ (self ):
52
+ val = next (self .base )
53
+ if isinstance (val , Exception ):
54
+ raise val
55
+ return val
56
+
57
+
47
58
class _ObsResults (tuple ):
48
59
def __getitem__ (self , item ):
49
60
if isinstance (item , slice ):
@@ -54,6 +65,9 @@ def __getitem__(self, item):
54
65
raise AtError (f"Evaluation failed: { val .args [0 ]} " ) from val
55
66
return val
56
67
68
+ def __iter__ (self ):
69
+ return _ObsResIter (super ().__iter__ ())
70
+
57
71
58
72
class ObservableList (list ):
59
73
"""Handles a list of Observables to be evaluated together.
@@ -265,7 +279,7 @@ def obseval(ring, obs):
265
279
"""Evaluate a single observable."""
266
280
267
281
def check_error (data , refpts ):
268
- return data if isinstance (data , AtError ) else data [refpts ]
282
+ return data if isinstance (data , Exception ) else data [refpts ]
269
283
270
284
obsneeds = obs .needs
271
285
obsrefs = getattr (obs , "_boolrefs" , None )
@@ -385,7 +399,7 @@ def ringeval(
385
399
trajs , orbits , rgdata , eldata , emdata , mxdata , geodata = ringeval (
386
400
ring , dp = dp , dct = dct , df = df
387
401
)
388
- return [ obseval (ring , ob ) for ob in self ]
402
+ return _ObsResults ( obseval (ring , ob ) for ob in self )
389
403
390
404
def check (self ) -> bool :
391
405
"""Check the evaluation
0 commit comments