From f0b0dc5240a3cf0ae94e085b6acf58b35ef02aed Mon Sep 17 00:00:00 2001 From: Anurudh Peduri Date: Tue, 29 Oct 2024 00:28:22 +0100 Subject: [PATCH] cleanup remaining uses of old control specialization interface --- qualtran/_infra/single_qubit_controlled.py | 78 ------------------- .../block_encoding/lcu_block_encoding.py | 66 ++++++---------- .../mean_estimation_operator.py | 46 ++++------- .../mean_estimation_operator_test.py | 7 +- .../qubitization_walk_operator_test.py | 61 ++++----------- 5 files changed, 59 insertions(+), 199 deletions(-) delete mode 100644 qualtran/_infra/single_qubit_controlled.py diff --git a/qualtran/_infra/single_qubit_controlled.py b/qualtran/_infra/single_qubit_controlled.py deleted file mode 100644 index 1166bd13f..000000000 --- a/qualtran/_infra/single_qubit_controlled.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -from typing import Iterable, Optional, Sequence, Tuple, TYPE_CHECKING - -import attrs - -from qualtran._infra.bloq import Bloq -from qualtran._infra.controlled import CtrlSpec -from qualtran._infra.registers import Register - -if TYPE_CHECKING: - from qualtran import AddControlledT, BloqBuilder, SoquetT - - -class SpecializedSingleQubitControlledExtension(Bloq): - """Add a specialized single-qubit controlled version of a Bloq. - - `control_val` is an optional single-bit control. When `control_val` is provided, - the `control_registers` property should return a single named qubit register, - and otherwise return an empty tuple. - - Example usage: - - @attrs.frozen - class MyGate(SpecializedSingleQubitControlledExtension): - control_val: Optional[int] = None - - @property - def control_registers() -> Tuple[Register, ...]: - return () if self.control_val is None else (Register('control', QBit()),) - """ - - control_val: Optional[int] - - @property - @abc.abstractmethod - def control_registers(self) -> Tuple[Register, ...]: ... - - def get_single_qubit_controlled_bloq( - self, control_val: int - ) -> 'SpecializedSingleQubitControlledExtension': - """Override this to provide a custom controlled bloq""" - return attrs.evolve(self, control_val=control_val) # type: ignore[misc] - - def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: - if self.control_val is None and ctrl_spec.shapes in [((),), ((1,),)]: - control_val = int(ctrl_spec.cvs[0].item()) - cbloq = self.get_single_qubit_controlled_bloq(control_val) - - if not hasattr(cbloq, 'control_registers'): - raise TypeError("{cbloq} should have attribute `control_registers`") - - (ctrl_reg,) = cbloq.control_registers - - def adder( - bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: dict[str, 'SoquetT'] - ) -> tuple[Iterable['SoquetT'], Iterable['SoquetT']]: - soqs = {ctrl_reg.name: ctrl_soqs[0]} | in_soqs - soqs = bb.add_d(cbloq, **soqs) - ctrl_soqs = [soqs.pop(ctrl_reg.name)] - return ctrl_soqs, soqs.values() - - return cbloq, adder - - return super().get_ctrl_system(ctrl_spec) diff --git a/qualtran/bloqs/block_encoding/lcu_block_encoding.py b/qualtran/bloqs/block_encoding/lcu_block_encoding.py index 189290834..98af6a8f9 100644 --- a/qualtran/bloqs/block_encoding/lcu_block_encoding.py +++ b/qualtran/bloqs/block_encoding/lcu_block_encoding.py @@ -23,12 +23,12 @@ BloqBuilder, BloqDocSpec, CtrlSpec, + QBit, Register, Side, Signature, SoquetT, ) -from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension from qualtran.bloqs.block_encoding.block_encoding_base import BlockEncoding from qualtran.bloqs.multiplexers.black_box_select import BlackBoxSelect from qualtran.bloqs.multiplexers.select_base import SelectOracle @@ -45,7 +45,7 @@ def _total_bits(registers: Union[Tuple[Register, ...], Signature]) -> int: @attrs.frozen -class SelectBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtension): +class SelectBlockEncoding(BlockEncoding): r"""LCU based block encoding using SELECT and PREPARE oracles. Builds the block encoding via @@ -96,11 +96,6 @@ class SelectBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtensi select: Union[BlackBoxSelect, SelectOracle] prepare: Union[BlackBoxPrepare, PrepareOracle] - control_val: Optional[int] = None - - @cached_property - def control_registers(self) -> Tuple[Register, ...]: - return self.select.control_registers @cached_property def ancilla_bitsize(self) -> int: @@ -137,14 +132,7 @@ def epsilon(self) -> SymbolicFloat: @cached_property def signature(self) -> Signature: - return Signature( - [ - *self.control_registers, - *self.selection_registers, - *self.junk_registers, - *self.target_registers, - ] - ) + return Signature([*self.selection_registers, *self.junk_registers, *self.target_registers]) @cached_property def signal_state(self) -> Union[BlackBoxPrepare, PrepareOracle]: @@ -158,26 +146,12 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: SoquetT) -> Dict[str, def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol': if reg is None: return Text('') - if reg.name == 'control': - return Circle(filled=bool(self.control_val)) else: return TextBox('B[H]') - def get_single_qubit_controlled_bloq(self, control_val: int) -> 'SelectBlockEncoding': - if self.control_val is not None: - raise ValueError( - "control_val is not None but trying to build controlled SelectBlockEncoding." - ) - c_select = self.select.controlled(ctrl_spec=CtrlSpec(cvs=control_val)) - if not isinstance(c_select, SelectOracle): - raise TypeError( - f"controlled version of {self.select} = {c_select} must also be a SelectOracle" - ) - return attrs.evolve(self, select=c_select, control_val=control_val) - @attrs.frozen -class LCUBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtension): +class LCUBlockEncoding(BlockEncoding): r"""LCU based block encoding using SELECT and PREPARE oracles. Builds the standard block encoding from an LCU as @@ -231,7 +205,7 @@ class LCUBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtension) @cached_property def control_registers(self) -> Tuple[Register, ...]: - return self.select.control_registers + return () if self.control_val is None else (Register('ctrl', QBit()),) @cached_property def ancilla_bitsize(self) -> int: @@ -287,8 +261,18 @@ def _extract_soqs(bloq: Bloq) -> Dict[str, 'SoquetT']: return {reg.name: soqs.pop(reg.name) for reg in bloq.signature.lefts()} soqs |= bb.add_d(self.prepare, **_extract_soqs(self.prepare)) - soqs |= bb.add_d(self.select, **_extract_soqs(self.select)) + + select_soqs = _extract_soqs(self.select) + if self.control_val is None: + soqs |= bb.add_d(self.select, **select_soqs) + else: + _, ctrl_select_adder = self.select.get_ctrl_system(CtrlSpec(cvs=self.control_val)) + (ctrl,), select_soqs_t = ctrl_select_adder(bb, [soqs.pop('ctrl')], select_soqs) + soqs |= {'ctrl': ctrl} + soqs |= dict(zip([reg.name for reg in self.select.signature], select_soqs_t)) + soqs |= bb.add_d(self.prepare.adjoint(), **_extract_soqs(self.prepare.adjoint())) + return soqs def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol': @@ -299,17 +283,13 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) - else: return TextBox('B[H]') - def get_single_qubit_controlled_bloq(self, control_val: int) -> 'LCUBlockEncoding': - if self.control_val is not None: - raise ValueError( - "control_val is not None but trying to build controlled SelectBlockEncoding." - ) - c_select = self.select.controlled(ctrl_spec=CtrlSpec(cvs=control_val)) - if not isinstance(c_select, SelectOracle): - raise TypeError( - f"controlled version of {self.select} = {c_select} must also be a SelectOracle" - ) - return attrs.evolve(self, select=c_select, control_val=control_val) + def adjoint(self) -> 'Bloq': + from qualtran.bloqs.mcmt.specialized_ctrl import ( + AdjointWithSpecializedCtrl, + SpecializeOnCtrlBit, + ) + + return AdjointWithSpecializedCtrl(self, SpecializeOnCtrlBit.ONE) @bloq_example diff --git a/qualtran/bloqs/mean_estimation/mean_estimation_operator.py b/qualtran/bloqs/mean_estimation/mean_estimation_operator.py index b145b5cba..1e5b5d7db 100644 --- a/qualtran/bloqs/mean_estimation/mean_estimation_operator.py +++ b/qualtran/bloqs/mean_estimation/mean_estimation_operator.py @@ -13,20 +13,20 @@ # limitations under the License. from functools import cached_property -from typing import Iterator, Optional, Tuple +from typing import Iterator, Tuple, TYPE_CHECKING import attrs -import cirq from numpy.typing import NDArray -from qualtran import CtrlSpec, Register, Signature -from qualtran._infra.gate_with_registers import GateWithRegisters, total_bits -from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension +from qualtran import GateWithRegisters, Register, Signature from qualtran.bloqs.mean_estimation.complex_phase_oracle import ComplexPhaseOracle from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle +if TYPE_CHECKING: + import cirq + @attrs.frozen class CodeForRandomVariable: @@ -65,7 +65,7 @@ def __attrs_post_init__(self): @attrs.frozen -class MeanEstimationOperator(GateWithRegisters, SpecializedSingleQubitControlledExtension): # type: ignore[misc] +class MeanEstimationOperator(GateWithRegisters): r"""Mean estimation operator $U=REFL_{p} ROT_{y}$ as per Sec 3.1 of arxiv.org:2208.07544. The MeanEstimationOperator (aka KO Operator) expects `CodeForRandomVariable` to specify the @@ -84,51 +84,39 @@ class MeanEstimationOperator(GateWithRegisters, SpecializedSingleQubitControlled """ code: CodeForRandomVariable - control_val: Optional[int] = None arctan_bitsize: int = 32 @cached_property def reflect(self) -> ReflectionUsingPrepare: - return ReflectionUsingPrepare( - self.code.synthesizer, global_phase=-1, control_val=self.control_val - ) + return ReflectionUsingPrepare(self.code.synthesizer, global_phase=-1) @cached_property def select(self) -> ComplexPhaseOracle: return ComplexPhaseOracle(self.code.encoder, self.arctan_bitsize) - @cached_property - def control_registers(self) -> Tuple[Register, ...]: - return self.code.encoder.control_registers - @cached_property def selection_registers(self) -> Tuple[Register, ...]: return self.code.encoder.selection_registers @cached_property def signature(self) -> Signature: - return Signature([*self.control_registers, *self.selection_registers]) + return Signature([*self.selection_registers]) def decompose_from_registers( self, *, - context: cirq.DecompositionContext, - **quregs: NDArray[cirq.Qid], # type:ignore[type-var] - ) -> Iterator[cirq.OP_TREE]: + context: 'cirq.DecompositionContext', + **quregs: NDArray['cirq.Qid'], # type:ignore[type-var] + ) -> Iterator['cirq.OP_TREE']: select_reg = {reg.name: quregs[reg.name] for reg in self.select.signature} reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.signature} yield self.select.on_registers(**select_reg) yield self.reflect.on_registers(**reflect_reg) - def get_single_qubit_controlled_bloq(self, control_val: int) -> 'MeanEstimationOperator': - c_encoder = self.code.encoder.controlled(ctrl_spec=CtrlSpec(cvs=control_val)) - assert isinstance(c_encoder, SelectOracle) - c_code = attrs.evolve(self.code, encoder=c_encoder) - return attrs.evolve(self, code=c_code, control_val=control_val) - - def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = [] - if self.control_val is not None: - wire_symbols.append("@" if self.control_val == 1 else "(0)") - wire_symbols += ['U_ko'] * (total_bits(self.signature) - total_bits(self.control_registers)) + def _circuit_diagram_info_( + self, args: 'cirq.CircuitDiagramInfoArgs' + ) -> 'cirq.CircuitDiagramInfo': + import cirq + + wire_symbols = ['U_ko'] * self.signature.n_qubits() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py b/qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py index cc62f7f4a..bbc6d9ff3 100644 --- a/qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py +++ b/qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py @@ -20,9 +20,8 @@ import pytest from attrs import frozen -from qualtran import BQUInt, QAny, QBit, QUInt, Register +from qualtran import BQUInt, QAny, QUInt, Register from qualtran._infra.gate_with_registers import get_named_qubits, total_bits -from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension from qualtran.bloqs.mean_estimation.mean_estimation_operator import ( CodeForRandomVariable, MeanEstimationOperator, @@ -52,7 +51,7 @@ def decompose_from_registers( # type:ignore[override] @frozen -class BernoulliEncoder(SelectOracle, SpecializedSingleQubitControlledExtension): # type: ignore[misc] +class BernoulliEncoder(SelectOracle): r"""Encodes Bernoulli random variable y0/y1 as $Enc|ii..i>|0> = |ii..i>|y_{i}>$ where i=0/1.""" p: float @@ -63,7 +62,7 @@ class BernoulliEncoder(SelectOracle, SpecializedSingleQubitControlledExtension): @cached_property def control_registers(self) -> Tuple[Register, ...]: - return () if self.control_val is None else (Register('control', QBit()),) + return () @cached_property def selection_registers(self) -> Tuple[Register, ...]: diff --git a/qualtran/bloqs/qubitization/qubitization_walk_operator_test.py b/qualtran/bloqs/qubitization/qubitization_walk_operator_test.py index 8e6d4f0cc..69b8da1d5 100644 --- a/qualtran/bloqs/qubitization/qubitization_walk_operator_test.py +++ b/qualtran/bloqs/qubitization/qubitization_walk_operator_test.py @@ -17,7 +17,7 @@ import pytest from qualtran import Adjoint -from qualtran._infra.gate_with_registers import get_named_qubits +from qualtran.bloqs.basic_gates import Power, XGate from qualtran.bloqs.chemistry.ising.walk_operator import get_walk_operator_for_1d_ising_model from qualtran.bloqs.mcmt import MultiControlPauli from qualtran.bloqs.multiplexers.select_pauli_lcu import SelectPauliLCU @@ -107,7 +107,7 @@ def test_qubitization_walk_operator_diagrams(): num_sites, eps = 4, 1e-1 walk, _ = get_walk_operator_for_1d_ising_model(num_sites, eps) # 1. Diagram for $W = SELECT.R_{L}$ - g, qubit_order, walk_circuit = construct_gate_helper_and_qubit_order(walk, decompose_once=True) + walk_circuit = walk.decompose_bloq().to_cirq_circuit() cirq.testing.assert_has_diagram( walk_circuit, ''' @@ -128,14 +128,7 @@ def test_qubitization_walk_operator_diagrams(): ) # 2. Diagram for $W^{2} = B[H].R_{L}.B[H].R_{L}$ - def decompose_twice(op): - ops = [] - for sub_op in cirq.decompose_once(op): - ops += cirq.decompose_once(sub_op) - return ops - - walk_squared_op = (walk**2).on_registers(**g.quregs) - circuit = cirq.Circuit(decompose_twice(walk_squared_op)) + circuit = Power(walk, 2).decompose_bloq().flatten_once().to_cirq_circuit() cirq.testing.assert_has_diagram( circuit, ''' @@ -154,13 +147,14 @@ def decompose_twice(op): target3: ──────B[H]─────────B[H]───────── ''', ) + # 3. Diagram for $Ctrl-W = Ctrl-B[H].Ctrl-R_{L}$ - controlled_walk_op = walk.controlled().on_registers(**g.quregs, ctrl=cirq.q('control')) - circuit = cirq.Circuit(cirq.decompose_once(controlled_walk_op)) + controlled_walk_op = walk.controlled().decompose_bloq() + circuit = controlled_walk_op.to_cirq_circuit() cirq.testing.assert_has_diagram( circuit, ''' -control: ──────@──────@───── +ctrl: ─────────@──────@───── │ │ selection0: ───B[H]───R_L─── │ │ @@ -177,22 +171,17 @@ def decompose_twice(op): target3: ──────B[H]───────── ''', ) - # 4. Diagram for $Ctrl-W = Ctrl-SELECT.Ctrl-R_{L}$ in terms of $Ctrl-SELECT$ and $PREPARE$. - gateset_to_keep = cirq.Gateset( - SelectPauliLCU, StatePreparationAliasSampling, MultiControlPauli, cirq.X - ) - def keep(op): - ret = op in gateset_to_keep - if op.gate is not None and isinstance(op.gate, Adjoint): - ret |= op.gate.subbloq in gateset_to_keep - return ret + # 4. Diagram for $Ctrl-W = Ctrl-SELECT.Ctrl-R_{L}$ in terms of $Ctrl-SELECT$ and $PREPARE$. + def pred(binst): + bloqs_to_keep = (SelectPauliLCU, StatePreparationAliasSampling, MultiControlPauli, XGate) + keep = binst.bloq_is(bloqs_to_keep) + if binst.bloq_is(Adjoint): + keep |= isinstance(binst.bloq.subbloq, bloqs_to_keep) + return not keep greedy_mm = cirq.GreedyQubitManager(prefix="ancilla", maximize_reuse=True) - context = cirq.DecompositionContext(greedy_mm) - circuit = cirq.Circuit( - cirq.decompose(controlled_walk_op, keep=keep, on_stuck_raise=None, context=context) - ) + circuit = controlled_walk_op.flatten(pred=pred).to_cirq_circuit(qubit_manager=greedy_mm) # pylint: disable=line-too-long cirq.testing.assert_has_diagram( circuit, @@ -226,7 +215,7 @@ def keep(op): │ │ ancilla_13: ────────────────────less_than_equal──────────────────────────less_than_equal────────────────── │ │ -control: ──────@────────────────┼───────────────────────────────Z───────Z┼──────────────────────────────── +ctrl: ─────────@────────────────┼───────────────────────────────Z───────Z┼──────────────────────────────── │ │ │ │ selection0: ───In───────────────StatePreparationAliasSampling───@(0)─────StatePreparationAliasSampling──── │ │ │ │ @@ -247,24 +236,6 @@ def keep(op): # pylint: enable=line-too-long -def test_qubitization_walk_operator_consistent_protocols_and_controlled(): - gate, _ = get_walk_operator_for_1d_ising_model(4, 1e-1) - op = gate.on_registers(**get_named_qubits(gate.signature)) - # Build controlled gate - equals_tester = cirq.testing.EqualsTester() - equals_tester.add_equality_group( - gate.controlled(), - gate.controlled(num_controls=1), - gate.controlled(control_values=(1,)), - op.controlled_by(cirq.q("control")).gate, - ) - equals_tester.add_equality_group( - gate.controlled(control_values=(0,)), - gate.controlled(num_controls=1, control_values=(0,)), - op.controlled_by(cirq.q("control"), control_values=(0,)).gate, - ) - - @pytest.mark.notebook def test_notebook(): execute_notebook('qubitization_walk_operator')