From ee153f49bc74ceb5dae3a70c59cfde2f3ea53121 Mon Sep 17 00:00:00 2001 From: Anurudh Peduri Date: Wed, 6 Nov 2024 21:11:18 +0100 Subject: [PATCH] support symbolic `CtrlSpec` --- qualtran/_infra/controlled.py | 73 ++++++++++++++----- qualtran/_infra/controlled_test.py | 44 ++++++++++- qualtran/bloqs/mcmt/ctrl_spec_and.py | 2 + qualtran/serialization/ctrl_spec.py | 9 ++- .../tensor/_tensor_data_manipulation.py | 4 + qualtran/symbolics/__init__.py | 4 +- qualtran/symbolics/math_funcs.py | 39 +--------- qualtran/symbolics/math_funcs_test.py | 24 +----- qualtran/symbolics/types.py | 59 +++++++++++++-- qualtran/symbolics/types_test.py | 28 +++++++ 10 files changed, 194 insertions(+), 92 deletions(-) create mode 100644 qualtran/symbolics/types_test.py diff --git a/qualtran/_infra/controlled.py b/qualtran/_infra/controlled.py index dca518e93..bcfc1ee6f 100644 --- a/qualtran/_infra/controlled.py +++ b/qualtran/_infra/controlled.py @@ -31,6 +31,7 @@ import numpy as np from numpy.typing import NDArray +from ..symbolics import is_symbolic, prod, Shaped, SymbolicInt from .bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError from .data_types import QBit, QDType from .gate_with_registers import GateWithRegisters @@ -55,18 +56,21 @@ def _cvs_convert( int, np.integer, NDArray[np.integer], + Shaped, Sequence[Union[int, np.integer]], Sequence[Sequence[Union[int, np.integer]]], - Sequence[NDArray[np.integer]], + Sequence[Union[NDArray[np.integer], Shaped]], ] -) -> Tuple[NDArray[np.integer], ...]: +) -> Tuple[Union[NDArray[np.integer], Shaped], ...]: + if isinstance(cvs, Shaped): + return (cvs,) if isinstance(cvs, (int, np.integer)): return (np.array(cvs),) if isinstance(cvs, np.ndarray): return (cvs,) if all(isinstance(cv, (int, np.integer)) for cv in cvs): return (np.asarray(cvs),) - return tuple(np.asarray(cv) for cv in cvs) + return tuple(cv if isinstance(cv, Shaped) else np.asarray(cv) for cv in cvs) @attrs.frozen(eq=False) @@ -115,7 +119,9 @@ class CtrlSpec: qdtypes: Tuple[QDType, ...] = attrs.field( default=QBit(), converter=lambda qt: (qt,) if isinstance(qt, QDType) else tuple(qt) ) - cvs: Tuple[NDArray[np.integer], ...] = attrs.field(default=1, converter=_cvs_convert) + cvs: Tuple[Union[NDArray[np.integer], Shaped], ...] = attrs.field( + default=1, converter=_cvs_convert + ) def __attrs_post_init__(self): assert len(self.qdtypes) == len(self.cvs) @@ -125,19 +131,29 @@ def num_ctrl_reg(self) -> int: return len(self.qdtypes) @cached_property - def shapes(self) -> Tuple[Tuple[int, ...], ...]: + def shapes(self) -> Tuple[Tuple[SymbolicInt, ...], ...]: """Tuple of shapes of control registers represented by this CtrlSpec.""" return tuple(cv.shape for cv in self.cvs) @cached_property - def num_qubits(self) -> int: + def concrete_shapes(self) -> tuple[tuple[int, ...], ...]: + """Tuple of shapes of control registers represented by this CtrlSpec.""" + shapes = self.shapes + if is_symbolic(*shapes): + raise ValueError(f"cannot get concrete shapes: found symbolic {self.shapes}") + return shapes # type: ignore + + @cached_property + def num_qubits(self) -> SymbolicInt: """Total number of qubits required for control registers represented by this CtrlSpec.""" return sum( - dtype.num_qubits * int(np.prod(shape)) - for dtype, shape in zip(self.qdtypes, self.shapes) + dtype.num_qubits * prod(shape) for dtype, shape in zip(self.qdtypes, self.shapes) ) - def activation_function_dtypes(self) -> Sequence[Tuple[QDType, Tuple[int, ...]]]: + def is_symbolic(self): + return is_symbolic(*self.qdtypes) or is_symbolic(*self.cvs) + + def activation_function_dtypes(self) -> Sequence[Tuple[QDType, Tuple[SymbolicInt, ...]]]: """The data types that serve as input to the 'activation function'. The activation function takes in (quantum) inputs of these types and shapes and determines @@ -165,6 +181,8 @@ def is_active(self, *vals: 'ClassicalValT') -> bool: Returns: True if the specific input values evaluate to `True` for this CtrlSpec. """ + if self.is_symbolic(): + raise ValueError(f"Cannot compute activation for symbolic {self}") if len(vals) != self.num_ctrl_reg: raise ValueError(f"Incorrect number of inputs for {self}: {len(vals)}.") @@ -180,19 +198,31 @@ def is_active(self, *vals: 'ClassicalValT') -> bool: return True def wire_symbol(self, i: int, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol': - # Return a circle for bits; a box otherwise. from qualtran.drawing import Circle, TextBox + cvs = self.cvs[i] + + if is_symbolic(cvs): + # control value is not given + return TextBox('ctrl') + + # Return a circle for bits; a box otherwise. + cv = cvs[idx] if reg.bitsize == 1: - cv = self.cvs[i][idx] return Circle(filled=(cv == 1)) - - cv = self.cvs[i][idx] - return TextBox(f'{cv}') + else: + return TextBox(f'{cv}') @cached_property - def _cvs_tuple(self) -> Tuple[int, ...]: - return tuple(cv for cvs in self.cvs for cv in tuple(cvs.reshape(-1))) + def __cvs_tuple(self) -> Tuple[Union[tuple[int, ...], Shaped], ...]: + """Serialize the control values for hashing and equality checking.""" + + def _serialize(cvs) -> Union[tuple[int, ...], Shaped]: + if isinstance(cvs, Shaped): + return cvs + return tuple(cvs.reshape(-1)) + + return tuple(_serialize(cvs) for cvs in self.cvs) def __eq__(self, other: Any) -> bool: if not isinstance(other, CtrlSpec): @@ -201,18 +231,22 @@ def __eq__(self, other: Any) -> bool: return ( other.qdtypes == self.qdtypes and other.shapes == self.shapes - and other._cvs_tuple == self._cvs_tuple + and other.__cvs_tuple == self.__cvs_tuple ) def __hash__(self): - return hash((self.qdtypes, self.shapes, self._cvs_tuple)) + return hash((self.qdtypes, self.shapes, self.__cvs_tuple)) def to_cirq_cv(self) -> 'cirq.SumOfProducts': """Convert CtrlSpec to cirq.SumOfProducts representation of control values.""" import cirq + if self.is_symbolic(): + raise ValueError(f"Cannot convert symbolic {self} to cirq control values.") + cirq_cv = [] for qdtype, cv in zip(self.qdtypes, self.cvs): + assert isinstance(cv, np.ndarray) for idx in Register('', qdtype, cv.shape).all_idxs(): cirq_cv += [*qdtype.to_bits(cv[idx])] return cirq.SumOfProducts([tuple(cirq_cv)]) @@ -256,11 +290,14 @@ def from_cirq_cv( def get_single_ctrl_bit(self) -> ControlBit: """If controlled by a single qubit, return the control bit, otherwise raise""" + if self.is_symbolic(): + raise ValueError(f"cannot get ctrl bit for symbolic {self}") if self.num_qubits != 1: raise ValueError(f"expected a single qubit control, got {self.num_qubits}") (qdtype,) = self.qdtypes (cv,) = self.cvs + assert isinstance(cv, np.ndarray) (idx,) = Register('', qdtype, cv.shape).all_idxs() (control_bit,) = qdtype.to_bits(cv[idx]) diff --git a/qualtran/_infra/controlled_test.py b/qualtran/_infra/controlled_test.py index 77d72432c..fcb9f207c 100644 --- a/qualtran/_infra/controlled_test.py +++ b/qualtran/_infra/controlled_test.py @@ -16,6 +16,7 @@ import attrs import numpy as np import pytest +import sympy import qualtran.testing as qlt_testing from qualtran import ( @@ -24,6 +25,7 @@ CompositeBloq, Controlled, CtrlSpec, + DecomposeTypeError, QBit, QInt, QUInt, @@ -52,6 +54,7 @@ from qualtran.drawing import get_musical_score_data from qualtran.drawing.musical_score import Circle, SoqData, TextBox from qualtran.simulation.tensor import cbloq_to_quimb, get_right_and_left_inds +from qualtran.symbolics import Shaped if TYPE_CHECKING: import cirq @@ -73,8 +76,10 @@ def test_ctrl_spec(): cspec3 = CtrlSpec(QInt(64), cvs=np.int64(234234)) assert cspec3 != cspec1 assert cspec3.qdtypes[0].num_qubits == 64 - assert cspec3.cvs[0] == 234234 - assert cspec3.cvs[0][tuple()] == 234234 + (cvs,) = cspec3.cvs + assert isinstance(cvs, np.ndarray) + assert cvs == 234234 + assert cvs[tuple()] == 234234 def test_ctrl_spec_shape(): @@ -97,7 +102,9 @@ def test_ctrl_spec_to_cirq_cv_roundtrip(): for ctrl_spec in ctrl_specs: assert ctrl_spec.to_cirq_cv() == cirq_cv.expand() - assert CtrlSpec.from_cirq_cv(cirq_cv, qdtypes=ctrl_spec.qdtypes, shapes=ctrl_spec.shapes) + assert CtrlSpec.from_cirq_cv( + cirq_cv, qdtypes=ctrl_spec.qdtypes, shapes=ctrl_spec.concrete_shapes + ) @pytest.mark.parametrize( @@ -120,6 +127,32 @@ def test_ctrl_spec_single_bit_raises(ctrl_spec: CtrlSpec): ctrl_spec.get_single_ctrl_bit() +@pytest.mark.parametrize("shape", [(1,), (10,), (10, 10)]) +def test_ctrl_spec_symbolic_cvs(shape: tuple[int, ...]): + ctrl_spec = CtrlSpec(cvs=Shaped(shape)) + assert ctrl_spec.is_symbolic() + assert ctrl_spec.num_qubits == np.prod(shape) + assert ctrl_spec.shapes == (shape,) + + +@pytest.mark.parametrize("shape", [(1,), (10,), (10, 10)]) +def test_ctrl_spec_symbolic_dtype(shape: tuple[int, ...]): + n = sympy.Symbol("n") + dtype = QUInt(n) + + ctrl_spec = CtrlSpec(qdtypes=dtype, cvs=Shaped(shape)) + + assert ctrl_spec.is_symbolic() + assert ctrl_spec.num_qubits == n * np.prod(shape) + assert ctrl_spec.shapes == (shape,) + + +def test_ctrl_spec_symbolic_wire_symbol(): + ctrl_spec = CtrlSpec(cvs=Shaped((10,))) + reg = Register('q', QBit()) + assert ctrl_spec.wire_symbol(0, reg) == TextBox('ctrl') + + def _test_cirq_equivalence(bloq: Bloq, gate: 'cirq.Gate'): import cirq @@ -431,11 +464,15 @@ def signature(self) -> 'Signature': return Signature([Register('x', QBit(), shape=(3,), side=Side.RIGHT)]) def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']: + if self.ctrl_spec.is_symbolic(): + raise DecomposeTypeError(f"cannot decompose {self} with symbolic {self.ctrl_spec=}") + one_or_zero = [ZeroState(), OneState()] ctrl_bloq = Controlled(And(*self.and_ctrl), ctrl_spec=self.ctrl_spec) ctrl_soqs = {} for reg, cvs in zip(ctrl_bloq.ctrl_regs, self.ctrl_spec.cvs): + assert isinstance(cvs, np.ndarray) soqs = np.empty(shape=reg.shape, dtype=object) for idx in reg.all_idxs(): soqs[idx] = bb.add(IntState(val=cvs[idx], bitsize=reg.dtype.num_qubits)) @@ -447,6 +484,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']: out_soqs = np.asarray([*ctrl_soqs.pop('ctrl'), ctrl_soqs.pop('target')]) # type: ignore[misc] for reg, cvs in zip(ctrl_bloq.ctrl_regs, self.ctrl_spec.cvs): + assert isinstance(cvs, np.ndarray) for idx in reg.all_idxs(): ctrl_soq = np.asarray(ctrl_soqs[reg.name])[idx] bb.add(IntEffect(val=cvs[idx], bitsize=reg.dtype.num_qubits), val=ctrl_soq) diff --git a/qualtran/bloqs/mcmt/ctrl_spec_and.py b/qualtran/bloqs/mcmt/ctrl_spec_and.py index 409bb5307..bc9d6f744 100644 --- a/qualtran/bloqs/mcmt/ctrl_spec_and.py +++ b/qualtran/bloqs/mcmt/ctrl_spec_and.py @@ -14,6 +14,7 @@ from functools import cached_property from typing import Optional, TYPE_CHECKING, Union +import numpy as np from attrs import frozen from qualtran import ( @@ -123,6 +124,7 @@ def _flat_cvs(self) -> Union[tuple[int, ...], HasLength]: flat_cvs: list[int] = [] for reg, cv in zip(self.control_registers, self.ctrl_spec.cvs): + assert isinstance(cv, np.ndarray) flat_cvs.extend(reg.dtype.to_bits_array(cv.ravel()).ravel()) return tuple(flat_cvs) diff --git a/qualtran/serialization/ctrl_spec.py b/qualtran/serialization/ctrl_spec.py index 301faf2fc..30895b51e 100644 --- a/qualtran/serialization/ctrl_spec.py +++ b/qualtran/serialization/ctrl_spec.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from qualtran import CtrlSpec from qualtran.protos import ctrl_spec_pb2 from qualtran.serialization import args, data_types +from qualtran.symbolics import Shaped def ctrl_spec_from_proto(spec: ctrl_spec_pb2.CtrlSpec) -> CtrlSpec: @@ -25,7 +25,12 @@ def ctrl_spec_from_proto(spec: ctrl_spec_pb2.CtrlSpec) -> CtrlSpec: def ctrl_spec_to_proto(spec: CtrlSpec) -> ctrl_spec_pb2.CtrlSpec: + def cvs_to_proto(cvs): + if isinstance(cvs, Shaped): + raise ValueError("cannot serialize Shaped") + return args.ndarray_to_proto(cvs) + return ctrl_spec_pb2.CtrlSpec( qdtypes=[data_types.data_type_to_proto(dtype) for dtype in spec.qdtypes], - cvs=[args.ndarray_to_proto(cvs) for cvs in spec.cvs], + cvs=[cvs_to_proto(cvs) for cvs in spec.cvs], ) diff --git a/qualtran/simulation/tensor/_tensor_data_manipulation.py b/qualtran/simulation/tensor/_tensor_data_manipulation.py index 97f8852e2..5029ae22e 100644 --- a/qualtran/simulation/tensor/_tensor_data_manipulation.py +++ b/qualtran/simulation/tensor/_tensor_data_manipulation.py @@ -68,11 +68,15 @@ def active_space_for_ctrl_spec( Returns a tuple of indices/slices that can be used to address into the ndarray, representing tensor data of shape `tensor_shape_from_signature(signature)`, and access the active subspace. """ + if ctrl_spec.is_symbolic(): + raise ValueError(f"cannot compute active space for symbolic {ctrl_spec=}") + out_ind, inp_ind = tensor_out_inp_shape_from_signature(signature) data_shape = out_ind + inp_ind active_idx: List[Union[int, slice]] = [slice(x) for x in data_shape] ctrl_idx = 0 for cv in ctrl_spec.cvs: + assert isinstance(cv, np.ndarray) for idx in itertools.product(*[range(sh) for sh in cv.shape]): active_idx[ctrl_idx] = int(cv[idx]) active_idx[ctrl_idx + len(out_ind)] = int(cv[idx]) diff --git a/qualtran/symbolics/__init__.py b/qualtran/symbolics/__init__.py index 9bee6b1fe..95cf224f0 100644 --- a/qualtran/symbolics/__init__.py +++ b/qualtran/symbolics/__init__.py @@ -28,8 +28,6 @@ sarg, sconj, sexp, - shape, - slen, smax, smin, ssqrt, @@ -38,7 +36,9 @@ from qualtran.symbolics.types import ( HasLength, is_symbolic, + shape, Shaped, + slen, SymbolicComplex, SymbolicFloat, SymbolicInt, diff --git a/qualtran/symbolics/math_funcs.py b/qualtran/symbolics/math_funcs.py index 19473ad20..5768042de 100644 --- a/qualtran/symbolics/math_funcs.py +++ b/qualtran/symbolics/math_funcs.py @@ -11,19 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Iterable, overload, Sized, Tuple, TypeVar, Union +from typing import cast, Iterable, overload, TypeVar import numpy as np import sympy -from qualtran.symbolics.types import ( - HasLength, - is_symbolic, - Shaped, - SymbolicComplex, - SymbolicFloat, - SymbolicInt, -) +from qualtran.symbolics.types import is_symbolic, SymbolicComplex, SymbolicFloat, SymbolicInt def pi(*args) -> SymbolicFloat: @@ -261,34 +254,6 @@ def sconj(x: SymbolicComplex) -> SymbolicComplex: return sympy.conjugate(x) if is_symbolic(x) else np.conjugate(x) -@overload -def slen(x: Sized) -> int: ... - - -@overload -def slen(x: Union[Shaped, HasLength]) -> sympy.Expr: ... - - -def slen(x: Union[Sized, Shaped, HasLength]) -> SymbolicInt: - if isinstance(x, Shaped): - return x.shape[0] - if isinstance(x, HasLength): - return x.n - return len(x) - - -@overload -def shape(x: np.ndarray) -> Tuple[int, ...]: ... - - -@overload -def shape(x: Shaped) -> Tuple[SymbolicInt, ...]: ... - - -def shape(x: Union[np.ndarray, Shaped]): - return x.shape - - def is_zero(x: SymbolicInt) -> bool: """check if a symbolic integer is zero diff --git a/qualtran/symbolics/math_funcs_test.py b/qualtran/symbolics/math_funcs_test.py index 993006d83..78be4eac0 100644 --- a/qualtran/symbolics/math_funcs_test.py +++ b/qualtran/symbolics/math_funcs_test.py @@ -18,19 +18,7 @@ import sympy from sympy.codegen.cfunctions import log2 as sympy_log2 -from qualtran.symbolics import ( - bit_length, - ceil, - is_symbolic, - is_zero, - log2, - sarg, - sexp, - Shaped, - slen, - smax, - smin, -) +from qualtran.symbolics import bit_length, ceil, is_zero, log2, sarg, sexp, smax, smin def test_log2(): @@ -130,16 +118,6 @@ def test_bit_length_symbolic_simplify(): assert b.subs({N: 2**n}) == n -@pytest.mark.parametrize( - "shape", - [(4,), (1, 2), (1, 2, 3), (sympy.Symbol('n'),), (sympy.Symbol('n'), sympy.Symbol('m'), 100)], -) -def test_shaped(shape: tuple[int, ...]): - shaped = Shaped(shape=shape) - assert is_symbolic(shaped) - assert slen(shaped) == shape[0] - - def test_is_zero(): assert is_zero(0) assert not is_zero(1) diff --git a/qualtran/symbolics/types.py b/qualtran/symbolics/types.py index 18714b6f8..e971db857 100644 --- a/qualtran/symbolics/types.py +++ b/qualtran/symbolics/types.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import overload, TypeVar, Union +from typing import overload, Sized, TypeVar, Union +import numpy as np import sympy from attrs import field, frozen, validators from typing_extensions import TypeIs @@ -31,13 +32,20 @@ class Shaped: """Symbolic value for an object that has a shape. - A Shaped object can be used as a symbolic replacement for any object that has an attribute `shape`, - for example numpy NDArrays. - Each dimension can be either an positive integer value or a sympy expression. + A Shaped object can be used as a symbolic replacement for any object that has an + attribute `shape`, for example numpy `NDArrays`. Each dimension can be either + a positive integer value or a sympy expression. - This is useful to do symbolic analysis of Bloqs whose call graph only depends on the shape of the input, - but not on the actual values. - For example, T-cost of the `QROM` Bloq depends only on the iteration length (shape) and not on actual data values. + For the symbolic variant of a tuple or sequence of values, see `HasLength`. + + This is useful to do symbolic analysis of Bloqs whose call graph only depends on the shape + of the input, but not on the actual values. For example, T-cost of the `QROM` Bloq depends + only on the iteration length (shape) and not on actual data values. In this case, for the + bloq attribute `data`, we can use the type: + + ```py + data: Union[NDArray, Shaped] + ``` """ shape: tuple[SymbolicInt, ...] = field(validator=validators.instance_of(tuple)) @@ -50,6 +58,15 @@ def is_symbolic(self): class HasLength: """Symbolic value for an object that has a length. + This is used as a "symbolic" tuple. The length can either be a positive integer + or a sympy expression. For example, if a bloq attribute is a tuple of ints, + we can use the type: + + ```py + values: Union[tuple, HasLength] + ``` + + For the symbolic variant of a NDArray, see `Shaped`. Note that we cannot override __len__ and return a sympy symbol because Python has special treatment for __len__ and expects you to return a non-negative integers. @@ -63,6 +80,34 @@ def is_symbolic(self): return True +@overload +def slen(x: Sized) -> int: ... + + +@overload +def slen(x: Union[Shaped, HasLength]) -> sympy.Expr: ... + + +def slen(x: Union[Sized, Shaped, HasLength]) -> SymbolicInt: + if isinstance(x, Shaped): + return x.shape[0] + if isinstance(x, HasLength): + return x.n + return len(x) + + +@overload +def shape(x: np.ndarray) -> tuple[int, ...]: ... + + +@overload +def shape(x: Shaped) -> tuple[SymbolicInt, ...]: ... + + +def shape(x: Union[np.ndarray, Shaped]): + return x.shape + + T = TypeVar('T') diff --git a/qualtran/symbolics/types_test.py b/qualtran/symbolics/types_test.py new file mode 100644 index 000000000..8e6656119 --- /dev/null +++ b/qualtran/symbolics/types_test.py @@ -0,0 +1,28 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import sympy + +from qualtran.symbolics import is_symbolic, Shaped, slen + + +@pytest.mark.parametrize( + "shape", + [(4,), (1, 2), (1, 2, 3), (sympy.Symbol('n'),), (sympy.Symbol('n'), sympy.Symbol('m'), 100)], +) +def test_shaped(shape: tuple[int, ...]): + shaped = Shaped(shape=shape) + assert is_symbolic(shaped) + assert slen(shaped) == shape[0]