Skip to content

Commit

Permalink
Add Search (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
glorialeezero authored Nov 26, 2024
2 parents c2f038f + fad84c6 commit e421144
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 10 deletions.
8 changes: 7 additions & 1 deletion qupsy/bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
from typing import cast

from qupsy.search import search
from qupsy.spec import parse_spec
from qupsy.utils import logger

Expand Down Expand Up @@ -29,4 +30,9 @@ def main() -> None:
if args.dry_run:
return

print("Hello, qupsy!")
logger.info("Searching for a solution")
try:
pgm = search(spec)
logger.info("Solution found: %s", pgm)
except ValueError:
logger.error("No solution found")
30 changes: 28 additions & 2 deletions qupsy/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,10 @@ def children(self) -> list[Aexp]:
return [self.a, self.b]

def __call__(self, memory: dict[str, int]) -> int:
return self.a(memory) // self.b(memory)
b = self.b(memory)
if b == 0:
raise ValueError("Division by zero")
return self.a(memory) // b


@dataclass
Expand Down Expand Up @@ -327,6 +330,8 @@ def children(self) -> list[Aexp]:

def __call__(self, qbits: list[LineQubit], memory: dict[str, int]) -> Gate:
idx = self.qreg(memory)
if idx >= len(qbits) or idx < 0:
raise ValueError(f"Index out of range: {idx} >= {len(qbits)}")
return cirq.H(qbits[idx]) # type: ignore


Expand All @@ -353,6 +358,8 @@ def children(self) -> list[Aexp]:

def __call__(self, qbits: list[LineQubit], memory: dict[str, int]) -> Gate:
idx = self.qreg(memory)
if idx >= len(qbits) or idx < 0:
raise ValueError(f"Index out of range: {idx} >= {len(qbits)}")
return cirq.X(qbits[idx]) # type: ignore


Expand Down Expand Up @@ -385,6 +392,8 @@ def children(self) -> list[Aexp]:

def __call__(self, qbits: list[LineQubit], memory: dict[str, int]) -> Gate:
idx = self.qreg(memory)
if idx >= len(qbits) or idx < 0:
raise ValueError(f"Index out of range: {idx} >= {len(qbits)}")
return cirq.Ry(rads=2 * np.arccos(np.sqrt(self.p(memory) / self.q(memory))))(qbits[idx]) # type: ignore


Expand Down Expand Up @@ -414,6 +423,12 @@ def children(self) -> list[Aexp]:
def __call__(self, qbits: list[LineQubit], memory: dict[str, int]) -> Gate:
idx1 = self.qreg1(memory)
idx2 = self.qreg2(memory)
if idx1 >= len(qbits) or idx1 < 0:
raise ValueError(f"Index out of range: {idx1} >= {len(qbits)}")
if idx2 >= len(qbits) or idx2 < 0:
raise ValueError(f"Index out of range: {idx2} >= {len(qbits)}")
if idx1 == idx2:
raise ValueError("Control and target qubits must be different")
return cirq.CX(qbits[idx1], qbits[idx2]) # type: ignore


Expand Down Expand Up @@ -453,6 +468,12 @@ def children(self) -> list[Aexp]:
def __call__(self, qbits: list[LineQubit], memory: dict[str, int]) -> Gate:
idx1 = self.qreg1(memory)
idx2 = self.qreg2(memory)
if idx1 >= len(qbits) or idx1 < 0:
raise ValueError(f"Index out of range: {idx1} >= {len(qbits)}")
if idx2 >= len(qbits) or idx2 < 0:
raise ValueError(f"Index out of range: {idx2} >= {len(qbits)}")
if idx1 == idx2:
raise ValueError("Control and target qubits must be different")
return cirq.Ry(rads=2 * np.arccos(np.sqrt(self.p(memory) / self.q(memory)))).controlled(num_controls=1)(qbits[idx1], qbits[idx2]) # type: ignore


Expand Down Expand Up @@ -601,7 +622,8 @@ def __call__(
for i in range(start, end):
memory[self.var] = i
self.body(qc, qbits, memory)
del memory[self.var]
if self.var in memory:
del memory[self.var]


@dataclass
Expand Down Expand Up @@ -658,6 +680,10 @@ def cost(self) -> int:
def depth(self) -> int:
return self.body.depth

@property
def terminated(self) -> bool:
return self.body.terminated

def __call__(self, n: int) -> Circuit:
circuit = Circuit()
qbits = LineQubit.range(n) # type: ignore
Expand Down
26 changes: 26 additions & 0 deletions qupsy/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from qupsy.language import Pgm
from qupsy.spec import Spec
from qupsy.transition import next
from qupsy.utils import logger
from qupsy.verify import verify
from qupsy.worklist import Worklist


def search(spec: Spec, *, initial_pgm: Pgm | None = None, n: str = "n") -> Pgm:
worklist = Worklist()
worklist.put(initial_pgm or Pgm(n))
while worklist.notEmpty():
current_pgm = worklist.get()
logger.debug(f"Current program: {current_pgm}")
verified: list[bool] = []
for testcase in spec.testcases:
if current_pgm.terminated:
if verify(testcase, current_pgm):
verified.append(True)
if len(verified) == len(spec.testcases):
return current_pgm
else:
break
else:
worklist.put(*next(current_pgm, spec.gates))
raise ValueError("No solution found")
11 changes: 7 additions & 4 deletions qupsy/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@


class TransitionVisitor:
def __init__(self) -> None:

def __init__(self, components: list[type[Gate]]):
self.n = ""
self.for_depth = 0
self.components = components

def visit_HoleAexp(self, aexp: HoleAexp) -> list[Aexp]:
return (
Expand Down Expand Up @@ -54,7 +56,7 @@ def visit_Aexp(self, aexp: Aexp) -> list[Aexp]:
return []

def visit_HoleGate(self, gate: HoleGate) -> list[Gate]:
return [g() for g in ALL_GATES if g != HoleGate]
return [g() for g in self.components if g != HoleGate]

def visit_Gate(self, gate: Gate) -> list[Gate]:
visitor_name = f"visit_{gate.__class__.__name__}"
Expand Down Expand Up @@ -163,6 +165,7 @@ def visit(self, pgm: Pgm) -> list[Pgm]:
return [Pgm(pgm.n, body) for body in bodies]


def next(pgm: Pgm) -> list[Pgm]:
visitor = TransitionVisitor()
def next(pgm: Pgm, components: list[type[Gate]] | None = None) -> list[Pgm]:
components = components or ALL_GATES
visitor = TransitionVisitor(components)
return visitor.visit(pgm)
9 changes: 6 additions & 3 deletions qupsy/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ def verify(testcase: tuple[npt.ArrayLike, npt.ArrayLike], pgm: Pgm) -> bool:
assert isinstance(input, np.ndarray)
assert isinstance(output, np.ndarray)
n = int(np.log2(input.size))
pgm_qc = pgm(n)
pgm_sv = cirq.final_state_vector(pgm_qc, initial_state=input, qubit_order=cirq.LineQubit.range(n)) # type: ignore
return cirq.linalg.allclose_up_to_global_phase(output, pgm_sv) # type: ignore
try:
pgm_qc = pgm(n)
pgm_sv = cirq.final_state_vector(pgm_qc, initial_state=input, qubit_order=cirq.LineQubit.range(n)) # type: ignore
return cirq.linalg.allclose_up_to_global_phase(output, pgm_sv) # type: ignore
except ValueError:
return False

0 comments on commit e421144

Please sign in to comment.