-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclonablegenerator.py
326 lines (263 loc) · 10.3 KB
/
clonablegenerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
from typing import Any, Callable, NamedTuple, Optional, TypeVar, Generic, Generator, Iterable, Iterator, cast
import abc
from dataclasses import dataclass
# Type parameters of a clonable generator, in the same roles as
# typing.Generator[YieldType, SendType, ReturnType].
T_Yield = TypeVar("T_Yield")  # values produced by Yield
T_Return = TypeVar("T_Return")  # final value carried by Return
T_Send = TypeVar("T_Send")  # values passed in through send()
# Auxiliary type variables for generic helpers.
T = TypeVar("T")
S = TypeVar("S")
class NIL:
    """Sentinel type meaning "no value yet", distinct from ``None``.

    ``Nil`` below is the canonical instance. It survives ``copy.copy``,
    ``copy.deepcopy`` and pickling as the very same object, so identity
    checks against ``Nil`` keep working after generator state is cloned.
    """

    def __repr__(self) -> str:
        return "Nil"

    def __reduce__(self) -> str:
        # A string result makes copy/deepcopy return the object itself and
        # makes pickle store a reference to the module-level name ``Nil``.
        return "Nil"


Nil = NIL()
@dataclass
class Yield(Generic[T_Yield]):
    """Wrapper for a value produced by one step of a clonable generator."""

    value: T_Yield
@dataclass
class Return(Generic[T_Return]):
    """Wrapper for the final value; signals that the generator has finished."""

    value: T_Return
class Break:
    """Marker returned from ``send_impl`` to leave the innermost active loop."""
@dataclass
class ClonableGenerator(abc.ABC, Generic[T_Yield, T_Send, T_Return]):
    """Generator-like object whose progress is kept in ordinary attributes.

    Subclasses implement :meth:`send`. ``step`` records how far execution has
    advanced, so the object's state is fully described by its attributes.
    """

    # Index of the step to execute next; preserved across send() calls.
    step: int = 0

    @abc.abstractmethod
    def send(self, value: Optional[T_Send]) -> Yield[T_Yield] | Return[T_Return]:
        """Advance by one step, delivering *value* (``None`` for a plain ``next``).

        Returns ``Yield`` for an intermediate value or ``Return`` once finished.
        """
        ...

    def as_generator(self) -> Generator[T_Yield, T_Send, T_Return]:
        """Wrap this object in a native Python generator.

        Values passed to the wrapper's ``send()`` are forwarded to
        :meth:`send`; ``next()`` forwards ``None``. The wrapper returns the
        value of the final ``Return``. Note: the returned generator object is
        NOT clonable; it exists for compatibility with code expecting a real
        generator.
        """
        value: Optional[T_Send] = None
        while True:
            result = self.send(value)
            if not isinstance(result, Yield):
                return result.value
            # Capture what the caller sends so it reaches the next step
            # (previously discarded, so send() on the wrapper had no effect).
            value = yield result.value
@dataclass
class YieldFrom(Generic[T_Yield, T_Send, T_Return]):
    """Request, returned from ``send_impl``, to delegate to a sub-generator.

    Works like ``yield from``: values produced by *gen* are passed through
    until it finishes, and its return value is stored on the delegating
    generator as ``yield_from_result``.
    """

    # Either another ClonableGenerator or an ordinary Python generator.
    # NOTE(review): a native generator here is presumably not clonable itself.
    gen: (
        ClonableGenerator[T_Yield, T_Send, T_Return]
        | Generator[T_Yield, T_Send, T_Return]
    )
class ReturnYieldFrom(
    Generic[T_Yield, T_Send, T_Return], YieldFrom[T_Yield, T_Send, T_Return]
):
    """``YieldFrom`` whose sub-generator result also finishes this generator.

    Equivalent to ``return (yield from gen)``: once *gen* completes, its
    return value is emitted as this generator's ``Return``.
    """
class LoopState:
    """Bookkeeping shared by loops driven through for_loop()/while_loop()."""

    # Step at which execution continues after the loop; None while it runs.
    end_step: Optional[int] = None
@dataclass
class ForLoopState(LoopState, Generic[T]):
    """Saved state of one ``for_loop``: its iterator and the current item."""

    # Iterator over the loop's iterable; advanced once per iteration.
    base_it: Iterator[T]
    # Item of the current iteration, replayed on each pass; Nil before the first.
    start_value: T | NIL = Nil
@dataclass
class WhileLoopState(LoopState):
    """Saved state of one ``while_loop``."""

    # Last condition result, replayed on re-entry instead of re-evaluating it.
    start_value: bool = False
class LoopStateContainer(NamedTuple):
    """All loop bookkeeping belonging to one generator."""

    # Loop states keyed by the step number of the loop header.
    by_step: dict[int, LoopState]
    # Loops that are currently running, innermost last.
    stack: list[LoopState]
class ClonableGeneratorImpl(ClonableGenerator[T_Yield, T_Send, T_Return]):
    """Base class for generators written as re-runnable, step-numbered code.

    Subclasses implement :meth:`send_impl`, which is re-executed from the top
    on every pass. Blocks guarded by :meth:`next_step`, and loops written with
    :meth:`for_loop` / :meth:`while_loop`, are handed consecutive step numbers
    during a pass; only the one whose number equals ``self.step`` runs. All
    progress therefore lives in plain attributes (``step``, ``loop_state``,
    ``_yield_from``, ...), which is presumably what lets an instance be
    cloned, e.g. with ``copy.deepcopy`` (not shown in this file).

    Minimal shape of a subclass::

        def send_impl(self, value):
            if self.next_step():
                return Yield(1)
            for x in self.for_loop(range(3)):
                if self.next_step():
                    return Yield(x)
            if self.next_step():
                return Return(None)
    """

    # Sub-generator currently delegated to, or None when not delegating.
    _yield_from: YieldFrom[T_Yield, T_Send, Any] | None
    # Return value of the most recently finished YieldFrom delegation.
    yield_from_result: Any
    # Per-loop bookkeeping; created lazily by the loop helpers.
    loop_state: LoopStateContainer

    def __init__(self):
        self._yield_from = None
        # Step number the next next_step()/for_loop()/while_loop() receives.
        self._next_auto_step = 0

    @abc.abstractmethod
    def send_impl(
        self, value: Optional[T_Send]
    ) -> (
        Yield[T_Yield]
        | Return[T_Return]
        | YieldFrom[T_Yield, T_Send, Any]
        | Break
    ):
        """Run one pass and report the outcome of the current step.

        Return ``Yield``/``Return`` to emit a value, ``Break`` to leave the
        innermost loop, or ``YieldFrom``/``ReturnYieldFrom`` to delegate.
        """
        ...

    def send(self, value: Optional[T_Send]) -> Yield[T_Yield] | Return[T_Return]:
        """Advance the generator until it yields or returns.

        Runs passes of :meth:`send_impl`, handling ``Break`` (ends the
        innermost loop, then re-runs) and ``YieldFrom``/``ReturnYieldFrom``
        (delegates to a sub-generator until it finishes). *value* goes to
        whichever receives control first, ``send_impl`` or the active
        sub-generator; later calls within the same ``send`` receive ``None``.
        """
        while True:
            if self._yield_from is None:
                # Step numbering restarts at 0 on every pass through send_impl
                self._next_auto_step = 0
                s = self.send_impl(value)
                value = None
                if isinstance(s, (Yield, Return)):
                    self.complete_step()
                    return s
                if isinstance(s, Break):
                    self.complete_step()
                    # Record where the innermost loop ended so later passes
                    # skip its body, then run the code after the loop.
                    self.loop_state.stack.pop(-1).end_step = self.step
                    continue
                assert isinstance(s, YieldFrom)
                self._yield_from = s
            assert self._yield_from is not None
            sub = self._yield_from.gen
            if isinstance(sub, ClonableGenerator):
                y_or_r = sub.send(value)
                if isinstance(y_or_r, Yield):
                    return y_or_r
                self.yield_from_result = y_or_r.value
            else:
                try:
                    if value is None:
                        return Yield(next(sub))
                    return Yield(sub.send(value))
                except StopIteration as e:
                    self.yield_from_result = e.value
            # The delegating step is finished
            self.complete_step()
            yf = self._yield_from
            self._yield_from = None
            value = None
            if isinstance(yf, ReturnYieldFrom):
                result = self.yield_from_result
                self.yield_from_result = None
                return Return(result)

    def _claim_step(self, explicit: Optional[int] = None) -> int:
        """Hand out the next step number (or *explicit*) and return it.

        Automatic numbering always continues right after the returned number.
        """
        number = self._next_auto_step if explicit is None else explicit
        self._next_auto_step = number + 1
        return number

    def next_step(self, custom_step: Optional[int] = None) -> bool:
        """Check step condition: claim a step number and test it against ``step``.

        Numbers are handed out in call order during each pass through
        send_impl, so a guard reached at the same point gets the same number
        on every pass and can be compared with ``self.step``, which persists
        across passes. *custom_step* pins the number explicitly (automatic
        numbering then continues after it), like ``start_step`` of the loop
        helpers; it used to be silently ignored.
        """
        return self.step == self._claim_step(custom_step)

    def skip_next_step(self):
        """Advance ``step`` by two: one increment here plus :meth:`complete_step`.

        NOTE(review): presumably meant to skip the following step; confirm how
        it interacts with the complete_step() that send() performs afterwards.
        """
        self.step += 1
        self.complete_step()

    def complete_step(self):
        """Mark the current step as done by advancing ``step``."""
        self.step += 1

    def _get_loop_state(self) -> LoopStateContainer:
        """Return the loop bookkeeping container, creating it on first use."""
        try:
            return self.loop_state
        except AttributeError:
            self.loop_state = LoopStateContainer({}, [])
            return self.loop_state

    def _skip_finished_loop(
        self, state: LoopState, end_step: Optional[int]
    ) -> Iterator[Any]:
        """Skip the body of a loop whose end is already recorded.

        Automatic numbering resumes at the recorded end step. When an explicit
        *end_step* differs from the recorded one (the loop was left via
        ``Break``, which records the current ``step``), re-anchor on
        *end_step*. If execution currently sits at the old end, move ``step``
        along too, so the steps after the loop still line up.
        """
        if end_step is not None and state.end_step != end_step:
            if self.step == state.end_step:
                self.step = end_step
            state.end_step = end_step
        self._next_auto_step = state.end_step
        return iter(())

    def for_loop(
        self,
        i: Iterable[T],
        start_step: Optional[int] = None,
        end_step: Optional[int] = None,
    ) -> Iterator[T]:
        """Resumable replacement for ``for x in i`` inside send_impl.

        Called once per pass. The returned iterator first replays the item
        saved by an earlier pass (so every pass sees the same ``x``) and only
        advances the underlying iterator when the body loops again within a
        single pass.

        Args:
            i: Iterable to loop over. ``iter(i)`` is taken each time the loop
                is entered at its header step.
            start_step: Explicit step number for the loop header; defaults to
                the next automatic number.
            end_step: Explicit step number at which execution continues after
                the loop; defaults to the step after the one that ended it.
        """
        loop_state = self._get_loop_state()
        start_step = self._claim_step(start_step)
        if self.step == start_step:
            # Entering the loop: fresh state. base_it is not advanced here so
            # that StopIteration is handled inside the iterator below.
            state = ForLoopState[T](iter(i))
            loop_state.by_step[start_step] = state
            loop_state.stack.append(state)
        else:
            found = loop_state.by_step[start_step]
            assert isinstance(found, ForLoopState)
            state = cast(ForLoopState[T], found)  # satisfy generic type check
            if state.end_step is not None:
                # Loop already complete: skip its body
                return self._skip_finished_loop(state, end_step)
        outer = self

        class inner_it(Iterator[T]):
            def __init__(self):
                self.first_run = True

            def __next__(self) -> T:
                assert start_step <= outer.step
                if self.first_run:
                    # First call of a pass replays the last item from base_it
                    self.first_run = False
                    if not isinstance(state.start_value, NIL):
                        return state.start_value
                try:
                    next_value = next(state.base_it)
                except StopIteration:
                    # Iterator exhausted: finish the loop
                    outer.complete_step()
                    loop_state.stack.pop(-1)
                    if end_step is not None:
                        outer.step = end_step
                    # Saved so later passes number the following steps alike
                    state.end_step = outer._next_auto_step = outer.step
                    raise
                # Remembered across passes and replayed by the next pass
                state.start_value = next_value
                # A new iteration restarts at the first step of the body
                outer.step = outer._next_auto_step = start_step + 1
                return next_value

            def __iter__(self):
                return self

        return inner_it()

    def while_loop(
        self,
        condition: bool | Callable[[], bool],
        start_step: Optional[int] = None,
        end_step: Optional[int] = None,
    ) -> Iterator[None]:
        """Resumable replacement for ``while condition:`` inside send_impl.

        *condition* is a bool or a zero-argument callable evaluated at the top
        of each iteration. A true result is saved so later passes re-enter the
        body without re-evaluating it; the first false result ends the loop.
        ``start_step`` and ``end_step`` work as in :meth:`for_loop`.
        """
        loop_state = self._get_loop_state()
        start_step = self._claim_step(start_step)
        if self.step == start_step:
            state = WhileLoopState()
            loop_state.by_step[start_step] = state
            loop_state.stack.append(state)
        else:
            found = loop_state.by_step[start_step]
            assert isinstance(found, WhileLoopState)
            state = found
            if state.end_step is not None:
                # Loop already complete: skip its body
                return self._skip_finished_loop(state, end_step)
        outer = self

        class inner_it(Iterator[None]):
            def __init__(self):
                self.first_run = True

            def __next__(self) -> None:
                assert start_step <= outer.step
                if self.first_run:
                    # First call of a pass replays the last condition result
                    self.first_run = False
                    if state.start_value:
                        return None
                next_value = condition() if callable(condition) else condition
                if not next_value:
                    # Condition false: finish the loop
                    outer.complete_step()
                    loop_state.stack.pop(-1)
                    if end_step is not None:
                        outer.step = end_step
                    # Saved so later passes number the following steps alike
                    state.end_step = outer._next_auto_step = outer.step
                    raise StopIteration
                # Remembered across passes and replayed by the next pass
                state.start_value = next_value
                # A new iteration restarts at the first step of the body
                outer.step = outer._next_auto_step = start_step + 1
                return None

            def __iter__(self):
                return self

        return inner_it()