Skip to content

Commit

Permalink
cleanup remaining uses of old control specialization interface
Browse files Browse the repository at this point in the history
  • Loading branch information
anurudhp committed Nov 7, 2024
1 parent 2b26082 commit f0b0dc5
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 199 deletions.
78 changes: 0 additions & 78 deletions qualtran/_infra/single_qubit_controlled.py

This file was deleted.

66 changes: 23 additions & 43 deletions qualtran/bloqs/block_encoding/lcu_block_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
BloqBuilder,
BloqDocSpec,
CtrlSpec,
QBit,
Register,
Side,
Signature,
SoquetT,
)
from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension
from qualtran.bloqs.block_encoding.block_encoding_base import BlockEncoding
from qualtran.bloqs.multiplexers.black_box_select import BlackBoxSelect
from qualtran.bloqs.multiplexers.select_base import SelectOracle
Expand All @@ -45,7 +45,7 @@ def _total_bits(registers: Union[Tuple[Register, ...], Signature]) -> int:


@attrs.frozen
class SelectBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtension):
class SelectBlockEncoding(BlockEncoding):
r"""LCU based block encoding using SELECT and PREPARE oracles.
Builds the block encoding via
Expand Down Expand Up @@ -96,11 +96,6 @@ class SelectBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtensi

select: Union[BlackBoxSelect, SelectOracle]
prepare: Union[BlackBoxPrepare, PrepareOracle]
control_val: Optional[int] = None

@cached_property
def control_registers(self) -> Tuple[Register, ...]:
return self.select.control_registers

@cached_property
def ancilla_bitsize(self) -> int:
Expand Down Expand Up @@ -137,14 +132,7 @@ def epsilon(self) -> SymbolicFloat:

@cached_property
def signature(self) -> Signature:
return Signature(
[
*self.control_registers,
*self.selection_registers,
*self.junk_registers,
*self.target_registers,
]
)
return Signature([*self.selection_registers, *self.junk_registers, *self.target_registers])

@cached_property
def signal_state(self) -> Union[BlackBoxPrepare, PrepareOracle]:
Expand All @@ -158,26 +146,12 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: SoquetT) -> Dict[str,
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
if reg is None:
return Text('')
if reg.name == 'control':
return Circle(filled=bool(self.control_val))
else:
return TextBox('B[H]')

def get_single_qubit_controlled_bloq(self, control_val: int) -> 'SelectBlockEncoding':
if self.control_val is not None:
raise ValueError(
"control_val is not None but trying to build controlled SelectBlockEncoding."
)
c_select = self.select.controlled(ctrl_spec=CtrlSpec(cvs=control_val))
if not isinstance(c_select, SelectOracle):
raise TypeError(
f"controlled version of {self.select} = {c_select} must also be a SelectOracle"
)
return attrs.evolve(self, select=c_select, control_val=control_val)


@attrs.frozen
class LCUBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtension):
class LCUBlockEncoding(BlockEncoding):
r"""LCU based block encoding using SELECT and PREPARE oracles.
Builds the standard block encoding from an LCU as
Expand Down Expand Up @@ -231,7 +205,7 @@ class LCUBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtension)

@cached_property
def control_registers(self) -> Tuple[Register, ...]:
return self.select.control_registers
return () if self.control_val is None else (Register('ctrl', QBit()),)

@cached_property
def ancilla_bitsize(self) -> int:
Expand Down Expand Up @@ -287,8 +261,18 @@ def _extract_soqs(bloq: Bloq) -> Dict[str, 'SoquetT']:
return {reg.name: soqs.pop(reg.name) for reg in bloq.signature.lefts()}

soqs |= bb.add_d(self.prepare, **_extract_soqs(self.prepare))
soqs |= bb.add_d(self.select, **_extract_soqs(self.select))

select_soqs = _extract_soqs(self.select)
if self.control_val is None:
soqs |= bb.add_d(self.select, **select_soqs)
else:
_, ctrl_select_adder = self.select.get_ctrl_system(CtrlSpec(cvs=self.control_val))
(ctrl,), select_soqs_t = ctrl_select_adder(bb, [soqs.pop('ctrl')], select_soqs)
soqs |= {'ctrl': ctrl}
soqs |= dict(zip([reg.name for reg in self.select.signature], select_soqs_t))

soqs |= bb.add_d(self.prepare.adjoint(), **_extract_soqs(self.prepare.adjoint()))

return soqs

def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
Expand All @@ -299,17 +283,13 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
else:
return TextBox('B[H]')

def get_single_qubit_controlled_bloq(self, control_val: int) -> 'LCUBlockEncoding':
if self.control_val is not None:
raise ValueError(
"control_val is not None but trying to build controlled SelectBlockEncoding."
)
c_select = self.select.controlled(ctrl_spec=CtrlSpec(cvs=control_val))
if not isinstance(c_select, SelectOracle):
raise TypeError(
f"controlled version of {self.select} = {c_select} must also be a SelectOracle"
)
return attrs.evolve(self, select=c_select, control_val=control_val)
def adjoint(self) -> 'Bloq':
from qualtran.bloqs.mcmt.specialized_ctrl import (
AdjointWithSpecializedCtrl,
SpecializeOnCtrlBit,
)

return AdjointWithSpecializedCtrl(self, SpecializeOnCtrlBit.ONE)


@bloq_example
Expand Down
46 changes: 17 additions & 29 deletions qualtran/bloqs/mean_estimation/mean_estimation_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@
# limitations under the License.

from functools import cached_property
from typing import Iterator, Optional, Tuple
from typing import Iterator, Tuple, TYPE_CHECKING

import attrs
import cirq
from numpy.typing import NDArray

from qualtran import CtrlSpec, Register, Signature
from qualtran._infra.gate_with_registers import GateWithRegisters, total_bits
from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension
from qualtran import GateWithRegisters, Register, Signature
from qualtran.bloqs.mean_estimation.complex_phase_oracle import ComplexPhaseOracle
from qualtran.bloqs.multiplexers.select_base import SelectOracle
from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare
from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle

if TYPE_CHECKING:
import cirq


@attrs.frozen
class CodeForRandomVariable:
Expand Down Expand Up @@ -65,7 +65,7 @@ def __attrs_post_init__(self):


@attrs.frozen
class MeanEstimationOperator(GateWithRegisters, SpecializedSingleQubitControlledExtension): # type: ignore[misc]
class MeanEstimationOperator(GateWithRegisters):
r"""Mean estimation operator $U=REFL_{p} ROT_{y}$ as per Sec 3.1 of arxiv.org:2208.07544.
The MeanEstimationOperator (aka KO Operator) expects `CodeForRandomVariable` to specify the
Expand All @@ -84,51 +84,39 @@ class MeanEstimationOperator(GateWithRegisters, SpecializedSingleQubitControlled
"""

code: CodeForRandomVariable
control_val: Optional[int] = None
arctan_bitsize: int = 32

@cached_property
def reflect(self) -> ReflectionUsingPrepare:
return ReflectionUsingPrepare(
self.code.synthesizer, global_phase=-1, control_val=self.control_val
)
return ReflectionUsingPrepare(self.code.synthesizer, global_phase=-1)

@cached_property
def select(self) -> ComplexPhaseOracle:
return ComplexPhaseOracle(self.code.encoder, self.arctan_bitsize)

@cached_property
def control_registers(self) -> Tuple[Register, ...]:
return self.code.encoder.control_registers

@cached_property
def selection_registers(self) -> Tuple[Register, ...]:
return self.code.encoder.selection_registers

@cached_property
def signature(self) -> Signature:
return Signature([*self.control_registers, *self.selection_registers])
return Signature([*self.selection_registers])

def decompose_from_registers(
self,
*,
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> Iterator[cirq.OP_TREE]:
context: 'cirq.DecompositionContext',
**quregs: NDArray['cirq.Qid'], # type:ignore[type-var]
) -> Iterator['cirq.OP_TREE']:
select_reg = {reg.name: quregs[reg.name] for reg in self.select.signature}
reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.signature}
yield self.select.on_registers(**select_reg)
yield self.reflect.on_registers(**reflect_reg)

def get_single_qubit_controlled_bloq(self, control_val: int) -> 'MeanEstimationOperator':
c_encoder = self.code.encoder.controlled(ctrl_spec=CtrlSpec(cvs=control_val))
assert isinstance(c_encoder, SelectOracle)
c_code = attrs.evolve(self.code, encoder=c_encoder)
return attrs.evolve(self, code=c_code, control_val=control_val)

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
wire_symbols = []
if self.control_val is not None:
wire_symbols.append("@" if self.control_val == 1 else "(0)")
wire_symbols += ['U_ko'] * (total_bits(self.signature) - total_bits(self.control_registers))
def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
import cirq

wire_symbols = ['U_ko'] * self.signature.n_qubits()
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
import pytest
from attrs import frozen

from qualtran import BQUInt, QAny, QBit, QUInt, Register
from qualtran import BQUInt, QAny, QUInt, Register
from qualtran._infra.gate_with_registers import get_named_qubits, total_bits
from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension
from qualtran.bloqs.mean_estimation.mean_estimation_operator import (
CodeForRandomVariable,
MeanEstimationOperator,
Expand Down Expand Up @@ -52,7 +51,7 @@ def decompose_from_registers( # type:ignore[override]


@frozen
class BernoulliEncoder(SelectOracle, SpecializedSingleQubitControlledExtension): # type: ignore[misc]
class BernoulliEncoder(SelectOracle):
r"""Encodes Bernoulli random variable y0/y1 as $Enc|ii..i>|0> = |ii..i>|y_{i}>$ where i=0/1."""

p: float
Expand All @@ -63,7 +62,7 @@ class BernoulliEncoder(SelectOracle, SpecializedSingleQubitControlledExtension):

@cached_property
def control_registers(self) -> Tuple[Register, ...]:
return () if self.control_val is None else (Register('control', QBit()),)
return ()

@cached_property
def selection_registers(self) -> Tuple[Register, ...]:
Expand Down
Loading

0 comments on commit f0b0dc5

Please sign in to comment.