Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support symbolic CtrlSpec #1491

Merged
merged 2 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 55 additions & 18 deletions qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import numpy as np
from numpy.typing import NDArray

from ..symbolics import is_symbolic, prod, Shaped, SymbolicInt
from .bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError
from .data_types import QBit, QDType
from .gate_with_registers import GateWithRegisters
Expand All @@ -55,18 +56,21 @@ def _cvs_convert(
int,
np.integer,
NDArray[np.integer],
Shaped,
Sequence[Union[int, np.integer]],
Sequence[Sequence[Union[int, np.integer]]],
Sequence[NDArray[np.integer]],
Sequence[Union[NDArray[np.integer], Shaped]],
]
) -> Tuple[NDArray[np.integer], ...]:
) -> Tuple[Union[NDArray[np.integer], Shaped], ...]:
if isinstance(cvs, Shaped):
return (cvs,)
if isinstance(cvs, (int, np.integer)):
return (np.array(cvs),)
if isinstance(cvs, np.ndarray):
return (cvs,)
if all(isinstance(cv, (int, np.integer)) for cv in cvs):
return (np.asarray(cvs),)
return tuple(np.asarray(cv) for cv in cvs)
return tuple(cv if isinstance(cv, Shaped) else np.asarray(cv) for cv in cvs)


@attrs.frozen(eq=False)
Expand Down Expand Up @@ -115,7 +119,9 @@ class CtrlSpec:
qdtypes: Tuple[QDType, ...] = attrs.field(
default=QBit(), converter=lambda qt: (qt,) if isinstance(qt, QDType) else tuple(qt)
)
cvs: Tuple[NDArray[np.integer], ...] = attrs.field(default=1, converter=_cvs_convert)
cvs: Tuple[Union[NDArray[np.integer], Shaped], ...] = attrs.field(
default=1, converter=_cvs_convert
)

def __attrs_post_init__(self):
assert len(self.qdtypes) == len(self.cvs)
Expand All @@ -125,19 +131,29 @@ def num_ctrl_reg(self) -> int:
return len(self.qdtypes)

@cached_property
def shapes(self) -> Tuple[Tuple[int, ...], ...]:
def shapes(self) -> Tuple[Tuple[SymbolicInt, ...], ...]:
"""Tuple of shapes of control registers represented by this CtrlSpec."""
return tuple(cv.shape for cv in self.cvs)

@cached_property
def num_qubits(self) -> int:
def concrete_shapes(self) -> tuple[tuple[int, ...], ...]:
"""Tuple of shapes of control registers represented by this CtrlSpec."""
shapes = self.shapes
if is_symbolic(*shapes):
raise ValueError(f"cannot get concrete shapes: found symbolic {self.shapes}")
return shapes # type: ignore

@cached_property
def num_qubits(self) -> SymbolicInt:
"""Total number of qubits required for control registers represented by this CtrlSpec."""
return sum(
dtype.num_qubits * int(np.prod(shape))
for dtype, shape in zip(self.qdtypes, self.shapes)
dtype.num_qubits * prod(shape) for dtype, shape in zip(self.qdtypes, self.shapes)
)

def activation_function_dtypes(self) -> Sequence[Tuple[QDType, Tuple[int, ...]]]:
def is_symbolic(self):
return is_symbolic(*self.qdtypes) or is_symbolic(*self.cvs)

def activation_function_dtypes(self) -> Sequence[Tuple[QDType, Tuple[SymbolicInt, ...]]]:
"""The data types that serve as input to the 'activation function'.
The activation function takes in (quantum) inputs of these types and shapes and determines
Expand Down Expand Up @@ -165,6 +181,8 @@ def is_active(self, *vals: 'ClassicalValT') -> bool:
Returns:
True if the specific input values evaluate to `True` for this CtrlSpec.
"""
if self.is_symbolic():
raise ValueError(f"Cannot compute activation for symbolic {self}")
if len(vals) != self.num_ctrl_reg:
raise ValueError(f"Incorrect number of inputs for {self}: {len(vals)}.")

Expand All @@ -180,19 +198,31 @@ def is_active(self, *vals: 'ClassicalValT') -> bool:
return True

def wire_symbol(self, i: int, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
# Return a circle for bits; a box otherwise.
from qualtran.drawing import Circle, TextBox

cvs = self.cvs[i]

if is_symbolic(cvs):
# control value is not given
return TextBox('ctrl')

# Return a circle for bits; a box otherwise.
cv = cvs[idx]
if reg.bitsize == 1:
cv = self.cvs[i][idx]
return Circle(filled=(cv == 1))

cv = self.cvs[i][idx]
return TextBox(f'{cv}')
else:
return TextBox(f'{cv}')

@cached_property
def _cvs_tuple(self) -> Tuple[int, ...]:
return tuple(cv for cvs in self.cvs for cv in tuple(cvs.reshape(-1)))
def __cvs_tuple(self) -> Tuple[Union[tuple[int, ...], Shaped], ...]:
"""Serialize the control values for hashing and equality checking."""

def _serialize(cvs) -> Union[tuple[int, ...], Shaped]:
if isinstance(cvs, Shaped):
return cvs
return tuple(cvs.reshape(-1))

return tuple(_serialize(cvs) for cvs in self.cvs)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, CtrlSpec):
Expand All @@ -201,18 +231,22 @@ def __eq__(self, other: Any) -> bool:
return (
other.qdtypes == self.qdtypes
and other.shapes == self.shapes
and other._cvs_tuple == self._cvs_tuple
and other.__cvs_tuple == self.__cvs_tuple
)

def __hash__(self):
return hash((self.qdtypes, self.shapes, self._cvs_tuple))
return hash((self.qdtypes, self.shapes, self.__cvs_tuple))

def to_cirq_cv(self) -> 'cirq.SumOfProducts':
"""Convert CtrlSpec to cirq.SumOfProducts representation of control values."""
import cirq

if self.is_symbolic():
raise ValueError(f"Cannot convert symbolic {self} to cirq control values.")

cirq_cv = []
for qdtype, cv in zip(self.qdtypes, self.cvs):
assert isinstance(cv, np.ndarray)
for idx in Register('', qdtype, cv.shape).all_idxs():
cirq_cv += [*qdtype.to_bits(cv[idx])]
return cirq.SumOfProducts([tuple(cirq_cv)])
Expand Down Expand Up @@ -256,11 +290,14 @@ def from_cirq_cv(

def get_single_ctrl_bit(self) -> ControlBit:
"""If controlled by a single qubit, return the control bit, otherwise raise"""
if self.is_symbolic():
raise ValueError(f"cannot get ctrl bit for symbolic {self}")
if self.num_qubits != 1:
raise ValueError(f"expected a single qubit control, got {self.num_qubits}")

(qdtype,) = self.qdtypes
(cv,) = self.cvs
assert isinstance(cv, np.ndarray)
(idx,) = Register('', qdtype, cv.shape).all_idxs()
(control_bit,) = qdtype.to_bits(cv[idx])

Expand Down
44 changes: 41 additions & 3 deletions qualtran/_infra/controlled_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import attrs
import numpy as np
import pytest
import sympy

import qualtran.testing as qlt_testing
from qualtran import (
Expand All @@ -24,6 +25,7 @@
CompositeBloq,
Controlled,
CtrlSpec,
DecomposeTypeError,
QBit,
QInt,
QUInt,
Expand Down Expand Up @@ -52,6 +54,7 @@
from qualtran.drawing import get_musical_score_data
from qualtran.drawing.musical_score import Circle, SoqData, TextBox
from qualtran.simulation.tensor import cbloq_to_quimb, get_right_and_left_inds
from qualtran.symbolics import Shaped

if TYPE_CHECKING:
import cirq
Expand All @@ -73,8 +76,10 @@ def test_ctrl_spec():
cspec3 = CtrlSpec(QInt(64), cvs=np.int64(234234))
assert cspec3 != cspec1
assert cspec3.qdtypes[0].num_qubits == 64
assert cspec3.cvs[0] == 234234
assert cspec3.cvs[0][tuple()] == 234234
(cvs,) = cspec3.cvs
assert isinstance(cvs, np.ndarray)
assert cvs == 234234
assert cvs[tuple()] == 234234


def test_ctrl_spec_shape():
Expand All @@ -97,7 +102,9 @@ def test_ctrl_spec_to_cirq_cv_roundtrip():

for ctrl_spec in ctrl_specs:
assert ctrl_spec.to_cirq_cv() == cirq_cv.expand()
assert CtrlSpec.from_cirq_cv(cirq_cv, qdtypes=ctrl_spec.qdtypes, shapes=ctrl_spec.shapes)
assert CtrlSpec.from_cirq_cv(
cirq_cv, qdtypes=ctrl_spec.qdtypes, shapes=ctrl_spec.concrete_shapes
)


@pytest.mark.parametrize(
Expand All @@ -120,6 +127,32 @@ def test_ctrl_spec_single_bit_raises(ctrl_spec: CtrlSpec):
ctrl_spec.get_single_ctrl_bit()


@pytest.mark.parametrize("shape", [(1,), (10,), (10, 10)])
def test_ctrl_spec_symbolic_cvs(shape: tuple[int, ...]):
ctrl_spec = CtrlSpec(cvs=Shaped(shape))
assert ctrl_spec.is_symbolic()
assert ctrl_spec.num_qubits == np.prod(shape)
assert ctrl_spec.shapes == (shape,)


@pytest.mark.parametrize("shape", [(1,), (10,), (10, 10)])
def test_ctrl_spec_symbolic_dtype(shape: tuple[int, ...]):
n = sympy.Symbol("n")
dtype = QUInt(n)

ctrl_spec = CtrlSpec(qdtypes=dtype, cvs=Shaped(shape))

assert ctrl_spec.is_symbolic()
assert ctrl_spec.num_qubits == n * np.prod(shape)
assert ctrl_spec.shapes == (shape,)


def test_ctrl_spec_symbolic_wire_symbol():
ctrl_spec = CtrlSpec(cvs=Shaped((10,)))
reg = Register('q', QBit())
assert ctrl_spec.wire_symbol(0, reg) == TextBox('ctrl')


def _test_cirq_equivalence(bloq: Bloq, gate: 'cirq.Gate'):
import cirq

Expand Down Expand Up @@ -431,11 +464,15 @@ def signature(self) -> 'Signature':
return Signature([Register('x', QBit(), shape=(3,), side=Side.RIGHT)])

def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']:
if self.ctrl_spec.is_symbolic():
raise DecomposeTypeError(f"cannot decompose {self} with symbolic {self.ctrl_spec=}")

one_or_zero = [ZeroState(), OneState()]
ctrl_bloq = Controlled(And(*self.and_ctrl), ctrl_spec=self.ctrl_spec)

ctrl_soqs = {}
for reg, cvs in zip(ctrl_bloq.ctrl_regs, self.ctrl_spec.cvs):
assert isinstance(cvs, np.ndarray)
soqs = np.empty(shape=reg.shape, dtype=object)
for idx in reg.all_idxs():
soqs[idx] = bb.add(IntState(val=cvs[idx], bitsize=reg.dtype.num_qubits))
Expand All @@ -447,6 +484,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']:
out_soqs = np.asarray([*ctrl_soqs.pop('ctrl'), ctrl_soqs.pop('target')]) # type: ignore[misc]

for reg, cvs in zip(ctrl_bloq.ctrl_regs, self.ctrl_spec.cvs):
assert isinstance(cvs, np.ndarray)
for idx in reg.all_idxs():
ctrl_soq = np.asarray(ctrl_soqs[reg.name])[idx]
bb.add(IntEffect(val=cvs[idx], bitsize=reg.dtype.num_qubits), val=ctrl_soq)
Expand Down
Loading
Loading