diff --git a/qupsy/language.py b/qupsy/language.py index bbebd49..149da6a 100644 --- a/qupsy/language.py +++ b/qupsy/language.py @@ -34,9 +34,9 @@ def depth(self) -> int: ) @property - def filled(self) -> bool: + def terminated(self) -> bool: for aexp in self.children: - if not aexp.filled: + if not aexp.terminated: return False return True @@ -104,7 +104,7 @@ def children(self) -> list[Aexp]: return [] @property - def filled(self) -> bool: + def terminated(self) -> bool: return False @@ -227,9 +227,9 @@ def depth(self) -> int: ) @property - def filled(self) -> bool: + def terminated(self) -> bool: for aexp in self.children: - if not aexp.filled: + if not aexp.terminated: return False return True @@ -253,7 +253,7 @@ def children(self) -> list[Aexp]: return [] @property - def filled(self) -> bool: + def terminated(self) -> bool: return False @@ -410,9 +410,9 @@ def depth(self) -> int: ) @property - def filled(self) -> bool: + def terminated(self) -> bool: for child in self.children: - if not child.filled: + if not child.terminated: return False return True @@ -436,7 +436,7 @@ def children(self) -> list[Cmd | Gate | Aexp]: return [] @property - def filled(self) -> bool: + def terminated(self) -> bool: return False diff --git a/qupsy/transition.py b/qupsy/transition.py index 52024de..35e3759 100644 --- a/qupsy/transition.py +++ b/qupsy/transition.py @@ -36,7 +36,7 @@ def visit_Aexp(self, aexp: Aexp) -> list[Aexp]: return visitor(aexp) for i, child in enumerate(aexp.children): - if child.filled: + if child.terminated: continue next_children = self.visit_Aexp(child) pre_args = aexp.children[:i] @@ -50,7 +50,7 @@ def visit_Aexp(self, aexp: Aexp) -> list[Aexp]: *(a.copy() for a in post_args), ) ) - + return ret return [] def visit_HoleGate(self, gate: HoleGate) -> list[Gate]: @@ -63,7 +63,7 @@ def visit_Gate(self, gate: Gate) -> list[Gate]: return visitor(gate) for i, child in enumerate(gate.children): - if child.filled: + if child.terminated: continue next_children = self.visit_Aexp(child) pre_args = gate.children[:i] @@ -84,13 +84,13 @@ def visit_HoleCmd(self, cmd: HoleCmd) -> list[Cmd]: return [SeqCmd(), ForCmd(var=f"i{self.for_depth}"), GateCmd()] def visit_SeqCmd(self, cmd: SeqCmd) -> list[Cmd]: - if not cmd.pre.filled: + if not cmd.pre.terminated: pres = self.visit_Cmd(cmd.pre) ret: list[Cmd] = [] for pre in pres: ret.append(SeqCmd(pre=pre, post=cmd.post.copy())) return ret - if not cmd.post.filled: + if not cmd.post.terminated: posts = self.visit_Cmd(cmd.post) ret: list[Cmd] = [] for post in posts: @@ -99,7 +99,7 @@ def visit_SeqCmd(self, cmd: SeqCmd) -> list[Cmd]: return [] def visit_ForCmd(self, cmd: ForCmd) -> list[Cmd]: - if not cmd.start.filled: + if not cmd.start.terminated: starts = self.visit_Aexp(cmd.start) ret: list[Cmd] = [] for start in starts: @@ -112,7 +112,7 @@ def visit_ForCmd(self, cmd: ForCmd) -> list[Cmd]: ) ) return ret - if not cmd.end.filled: + if not cmd.end.terminated: ends = self.visit_Aexp(cmd.end) ret: list[Cmd] = [] for end in ends: @@ -125,7 +125,7 @@ def visit_ForCmd(self, cmd: ForCmd) -> list[Cmd]: ) ) return ret - if not cmd.body.filled: + if not cmd.body.terminated: self.for_depth += 1 bodies = self.visit_Cmd(cmd.body) self.for_depth -= 1 diff --git a/tests/test_transition.py b/tests/test_transition.py index 21010ec..fc2a4eb 100644 --- a/tests/test_transition.py +++ b/tests/test_transition.py @@ -2,12 +2,14 @@ ALL_AEXPS, ALL_GATES, CX, + Add, Aexp, Cmd, ForCmd, Gate, GateCmd, H, + HoleAexp, Integer, Pgm, SeqCmd, @@ -78,3 +80,18 @@ def test_hole_aexp3(): if type(pgm.body.gate.qreg2) not in [Integer, Var]: aexp_types.remove(type(pgm.body.gate.qreg2)) assert len(aexp_types) == 3 + + +def test_next_aexp(): + pgm = Pgm("n", GateCmd(CX(Integer(0), Add()))) + aexp_types: list[type[Aexp]] = ALL_AEXPS.copy() + pgms = next(pgm) + for pgm in pgms: + assert isinstance(pgm.body, GateCmd) + assert isinstance(pgm.body.gate, CX) + assert isinstance(pgm.body.gate.qreg2, Add) + assert type(pgm.body.gate.qreg2.a) in aexp_types + if type(pgm.body.gate.qreg2.a) not in [Integer, Var]: + aexp_types.remove(type(pgm.body.gate.qreg2.a)) + assert isinstance(pgm.body.gate.qreg2.b, HoleAexp) + assert len(aexp_types) == 3