Skip to content

Commit

Permalink
support symbolic CtrlSpec
Browse files Browse the repository at this point in the history
  • Loading branch information
anurudhp committed Nov 9, 2024
1 parent 188b663 commit ee153f4
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 92 deletions.
73 changes: 55 additions & 18 deletions qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import numpy as np
from numpy.typing import NDArray

from ..symbolics import is_symbolic, prod, Shaped, SymbolicInt
from .bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError
from .data_types import QBit, QDType
from .gate_with_registers import GateWithRegisters
Expand All @@ -55,18 +56,21 @@ def _cvs_convert(
int,
np.integer,
NDArray[np.integer],
Shaped,
Sequence[Union[int, np.integer]],
Sequence[Sequence[Union[int, np.integer]]],
Sequence[NDArray[np.integer]],
Sequence[Union[NDArray[np.integer], Shaped]],
]
) -> Tuple[NDArray[np.integer], ...]:
) -> Tuple[Union[NDArray[np.integer], Shaped], ...]:
if isinstance(cvs, Shaped):
return (cvs,)
if isinstance(cvs, (int, np.integer)):
return (np.array(cvs),)
if isinstance(cvs, np.ndarray):
return (cvs,)
if all(isinstance(cv, (int, np.integer)) for cv in cvs):
return (np.asarray(cvs),)
return tuple(np.asarray(cv) for cv in cvs)
return tuple(cv if isinstance(cv, Shaped) else np.asarray(cv) for cv in cvs)


@attrs.frozen(eq=False)
Expand Down Expand Up @@ -115,7 +119,9 @@ class CtrlSpec:
qdtypes: Tuple[QDType, ...] = attrs.field(
default=QBit(), converter=lambda qt: (qt,) if isinstance(qt, QDType) else tuple(qt)
)
cvs: Tuple[NDArray[np.integer], ...] = attrs.field(default=1, converter=_cvs_convert)
cvs: Tuple[Union[NDArray[np.integer], Shaped], ...] = attrs.field(
default=1, converter=_cvs_convert
)

def __attrs_post_init__(self):
assert len(self.qdtypes) == len(self.cvs)
Expand All @@ -125,19 +131,29 @@ def num_ctrl_reg(self) -> int:
return len(self.qdtypes)

@cached_property
def shapes(self) -> Tuple[Tuple[int, ...], ...]:
def shapes(self) -> Tuple[Tuple[SymbolicInt, ...], ...]:
"""Tuple of shapes of control registers represented by this CtrlSpec."""
return tuple(cv.shape for cv in self.cvs)

@cached_property
def num_qubits(self) -> int:
def concrete_shapes(self) -> tuple[tuple[int, ...], ...]:
"""Tuple of shapes of control registers represented by this CtrlSpec."""
shapes = self.shapes
if is_symbolic(*shapes):
raise ValueError(f"cannot get concrete shapes: found symbolic {self.shapes}")
return shapes # type: ignore

@cached_property
def num_qubits(self) -> SymbolicInt:
"""Total number of qubits required for control registers represented by this CtrlSpec."""
return sum(
dtype.num_qubits * int(np.prod(shape))
for dtype, shape in zip(self.qdtypes, self.shapes)
dtype.num_qubits * prod(shape) for dtype, shape in zip(self.qdtypes, self.shapes)
)

def activation_function_dtypes(self) -> Sequence[Tuple[QDType, Tuple[int, ...]]]:
def is_symbolic(self):
return is_symbolic(*self.qdtypes) or is_symbolic(*self.cvs)

def activation_function_dtypes(self) -> Sequence[Tuple[QDType, Tuple[SymbolicInt, ...]]]:
"""The data types that serve as input to the 'activation function'.
The activation function takes in (quantum) inputs of these types and shapes and determines
Expand Down Expand Up @@ -165,6 +181,8 @@ def is_active(self, *vals: 'ClassicalValT') -> bool:
Returns:
True if the specific input values evaluate to `True` for this CtrlSpec.
"""
if self.is_symbolic():
raise ValueError(f"Cannot compute activation for symbolic {self}")
if len(vals) != self.num_ctrl_reg:
raise ValueError(f"Incorrect number of inputs for {self}: {len(vals)}.")

Expand All @@ -180,19 +198,31 @@ def is_active(self, *vals: 'ClassicalValT') -> bool:
return True

def wire_symbol(self, i: int, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
# Return a circle for bits; a box otherwise.
from qualtran.drawing import Circle, TextBox

cvs = self.cvs[i]

if is_symbolic(cvs):
# control value is not given
return TextBox('ctrl')

# Return a circle for bits; a box otherwise.
cv = cvs[idx]
if reg.bitsize == 1:
cv = self.cvs[i][idx]
return Circle(filled=(cv == 1))

cv = self.cvs[i][idx]
return TextBox(f'{cv}')
else:
return TextBox(f'{cv}')

@cached_property
def _cvs_tuple(self) -> Tuple[int, ...]:
return tuple(cv for cvs in self.cvs for cv in tuple(cvs.reshape(-1)))
def __cvs_tuple(self) -> Tuple[Union[tuple[int, ...], Shaped], ...]:
"""Serialize the control values for hashing and equality checking."""

def _serialize(cvs) -> Union[tuple[int, ...], Shaped]:
if isinstance(cvs, Shaped):
return cvs
return tuple(cvs.reshape(-1))

return tuple(_serialize(cvs) for cvs in self.cvs)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, CtrlSpec):
Expand All @@ -201,18 +231,22 @@ def __eq__(self, other: Any) -> bool:
return (
other.qdtypes == self.qdtypes
and other.shapes == self.shapes
and other._cvs_tuple == self._cvs_tuple
and other.__cvs_tuple == self.__cvs_tuple
)

def __hash__(self):
return hash((self.qdtypes, self.shapes, self._cvs_tuple))
return hash((self.qdtypes, self.shapes, self.__cvs_tuple))

def to_cirq_cv(self) -> 'cirq.SumOfProducts':
"""Convert CtrlSpec to cirq.SumOfProducts representation of control values."""
import cirq

if self.is_symbolic():
raise ValueError(f"Cannot convert symbolic {self} to cirq control values.")

cirq_cv = []
for qdtype, cv in zip(self.qdtypes, self.cvs):
assert isinstance(cv, np.ndarray)
for idx in Register('', qdtype, cv.shape).all_idxs():
cirq_cv += [*qdtype.to_bits(cv[idx])]
return cirq.SumOfProducts([tuple(cirq_cv)])
Expand Down Expand Up @@ -256,11 +290,14 @@ def from_cirq_cv(

def get_single_ctrl_bit(self) -> ControlBit:
"""If controlled by a single qubit, return the control bit, otherwise raise"""
if self.is_symbolic():
raise ValueError(f"cannot get ctrl bit for symbolic {self}")
if self.num_qubits != 1:
raise ValueError(f"expected a single qubit control, got {self.num_qubits}")

(qdtype,) = self.qdtypes
(cv,) = self.cvs
assert isinstance(cv, np.ndarray)
(idx,) = Register('', qdtype, cv.shape).all_idxs()
(control_bit,) = qdtype.to_bits(cv[idx])

Expand Down
44 changes: 41 additions & 3 deletions qualtran/_infra/controlled_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import attrs
import numpy as np
import pytest
import sympy

import qualtran.testing as qlt_testing
from qualtran import (
Expand All @@ -24,6 +25,7 @@
CompositeBloq,
Controlled,
CtrlSpec,
DecomposeTypeError,
QBit,
QInt,
QUInt,
Expand Down Expand Up @@ -52,6 +54,7 @@
from qualtran.drawing import get_musical_score_data
from qualtran.drawing.musical_score import Circle, SoqData, TextBox
from qualtran.simulation.tensor import cbloq_to_quimb, get_right_and_left_inds
from qualtran.symbolics import Shaped

if TYPE_CHECKING:
import cirq
Expand All @@ -73,8 +76,10 @@ def test_ctrl_spec():
cspec3 = CtrlSpec(QInt(64), cvs=np.int64(234234))
assert cspec3 != cspec1
assert cspec3.qdtypes[0].num_qubits == 64
assert cspec3.cvs[0] == 234234
assert cspec3.cvs[0][tuple()] == 234234
(cvs,) = cspec3.cvs
assert isinstance(cvs, np.ndarray)
assert cvs == 234234
assert cvs[tuple()] == 234234


def test_ctrl_spec_shape():
Expand All @@ -97,7 +102,9 @@ def test_ctrl_spec_to_cirq_cv_roundtrip():

for ctrl_spec in ctrl_specs:
assert ctrl_spec.to_cirq_cv() == cirq_cv.expand()
assert CtrlSpec.from_cirq_cv(cirq_cv, qdtypes=ctrl_spec.qdtypes, shapes=ctrl_spec.shapes)
assert CtrlSpec.from_cirq_cv(
cirq_cv, qdtypes=ctrl_spec.qdtypes, shapes=ctrl_spec.concrete_shapes
)


@pytest.mark.parametrize(
Expand All @@ -120,6 +127,32 @@ def test_ctrl_spec_single_bit_raises(ctrl_spec: CtrlSpec):
ctrl_spec.get_single_ctrl_bit()


@pytest.mark.parametrize("shape", [(1,), (10,), (10, 10)])
def test_ctrl_spec_symbolic_cvs(shape: tuple[int, ...]):
ctrl_spec = CtrlSpec(cvs=Shaped(shape))
assert ctrl_spec.is_symbolic()
assert ctrl_spec.num_qubits == np.prod(shape)
assert ctrl_spec.shapes == (shape,)


@pytest.mark.parametrize("shape", [(1,), (10,), (10, 10)])
def test_ctrl_spec_symbolic_dtype(shape: tuple[int, ...]):
n = sympy.Symbol("n")
dtype = QUInt(n)

ctrl_spec = CtrlSpec(qdtypes=dtype, cvs=Shaped(shape))

assert ctrl_spec.is_symbolic()
assert ctrl_spec.num_qubits == n * np.prod(shape)
assert ctrl_spec.shapes == (shape,)


def test_ctrl_spec_symbolic_wire_symbol():
ctrl_spec = CtrlSpec(cvs=Shaped((10,)))
reg = Register('q', QBit())
assert ctrl_spec.wire_symbol(0, reg) == TextBox('ctrl')


def _test_cirq_equivalence(bloq: Bloq, gate: 'cirq.Gate'):
import cirq

Expand Down Expand Up @@ -431,11 +464,15 @@ def signature(self) -> 'Signature':
return Signature([Register('x', QBit(), shape=(3,), side=Side.RIGHT)])

def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']:
if self.ctrl_spec.is_symbolic():
raise DecomposeTypeError(f"cannot decompose {self} with symbolic {self.ctrl_spec=}")

one_or_zero = [ZeroState(), OneState()]
ctrl_bloq = Controlled(And(*self.and_ctrl), ctrl_spec=self.ctrl_spec)

ctrl_soqs = {}
for reg, cvs in zip(ctrl_bloq.ctrl_regs, self.ctrl_spec.cvs):
assert isinstance(cvs, np.ndarray)
soqs = np.empty(shape=reg.shape, dtype=object)
for idx in reg.all_idxs():
soqs[idx] = bb.add(IntState(val=cvs[idx], bitsize=reg.dtype.num_qubits))
Expand All @@ -447,6 +484,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']:
out_soqs = np.asarray([*ctrl_soqs.pop('ctrl'), ctrl_soqs.pop('target')]) # type: ignore[misc]

for reg, cvs in zip(ctrl_bloq.ctrl_regs, self.ctrl_spec.cvs):
assert isinstance(cvs, np.ndarray)
for idx in reg.all_idxs():
ctrl_soq = np.asarray(ctrl_soqs[reg.name])[idx]
bb.add(IntEffect(val=cvs[idx], bitsize=reg.dtype.num_qubits), val=ctrl_soq)
Expand Down
2 changes: 2 additions & 0 deletions qualtran/bloqs/mcmt/ctrl_spec_and.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from functools import cached_property
from typing import Optional, TYPE_CHECKING, Union

import numpy as np
from attrs import frozen

from qualtran import (
Expand Down Expand Up @@ -123,6 +124,7 @@ def _flat_cvs(self) -> Union[tuple[int, ...], HasLength]:

flat_cvs: list[int] = []
for reg, cv in zip(self.control_registers, self.ctrl_spec.cvs):
assert isinstance(cv, np.ndarray)
flat_cvs.extend(reg.dtype.to_bits_array(cv.ravel()).ravel())
return tuple(flat_cvs)

Expand Down
9 changes: 7 additions & 2 deletions qualtran/serialization/ctrl_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from qualtran import CtrlSpec
from qualtran.protos import ctrl_spec_pb2
from qualtran.serialization import args, data_types
from qualtran.symbolics import Shaped


def ctrl_spec_from_proto(spec: ctrl_spec_pb2.CtrlSpec) -> CtrlSpec:
Expand All @@ -25,7 +25,12 @@ def ctrl_spec_from_proto(spec: ctrl_spec_pb2.CtrlSpec) -> CtrlSpec:


def ctrl_spec_to_proto(spec: CtrlSpec) -> ctrl_spec_pb2.CtrlSpec:
def cvs_to_proto(cvs):
if isinstance(cvs, Shaped):
raise ValueError("cannot serialize Shaped")
return args.ndarray_to_proto(cvs)

return ctrl_spec_pb2.CtrlSpec(
qdtypes=[data_types.data_type_to_proto(dtype) for dtype in spec.qdtypes],
cvs=[args.ndarray_to_proto(cvs) for cvs in spec.cvs],
cvs=[cvs_to_proto(cvs) for cvs in spec.cvs],
)
4 changes: 4 additions & 0 deletions qualtran/simulation/tensor/_tensor_data_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,15 @@ def active_space_for_ctrl_spec(
Returns a tuple of indices/slices that can be used to address into the ndarray, representing
tensor data of shape `tensor_shape_from_signature(signature)`, and access the active subspace.
"""
if ctrl_spec.is_symbolic():
raise ValueError(f"cannot compute active space for symbolic {ctrl_spec=}")

out_ind, inp_ind = tensor_out_inp_shape_from_signature(signature)
data_shape = out_ind + inp_ind
active_idx: List[Union[int, slice]] = [slice(x) for x in data_shape]
ctrl_idx = 0
for cv in ctrl_spec.cvs:
assert isinstance(cv, np.ndarray)
for idx in itertools.product(*[range(sh) for sh in cv.shape]):
active_idx[ctrl_idx] = int(cv[idx])
active_idx[ctrl_idx + len(out_ind)] = int(cv[idx])
Expand Down
4 changes: 2 additions & 2 deletions qualtran/symbolics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
sarg,
sconj,
sexp,
shape,
slen,
smax,
smin,
ssqrt,
Expand All @@ -38,7 +36,9 @@
from qualtran.symbolics.types import (
HasLength,
is_symbolic,
shape,
Shaped,
slen,
SymbolicComplex,
SymbolicFloat,
SymbolicInt,
Expand Down
Loading

0 comments on commit ee153f4

Please sign in to comment.