Skip to content

Commit f0b0dc5

Browse files
committed
cleanup remaining uses of old control specialization interface
1 parent 2b26082 commit f0b0dc5

File tree

5 files changed

+59
-199
lines changed

5 files changed

+59
-199
lines changed

qualtran/_infra/single_qubit_controlled.py

Lines changed: 0 additions & 78 deletions
This file was deleted.

qualtran/bloqs/block_encoding/lcu_block_encoding.py

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
BloqBuilder,
2424
BloqDocSpec,
2525
CtrlSpec,
26+
QBit,
2627
Register,
2728
Side,
2829
Signature,
2930
SoquetT,
3031
)
31-
from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension
3232
from qualtran.bloqs.block_encoding.block_encoding_base import BlockEncoding
3333
from qualtran.bloqs.multiplexers.black_box_select import BlackBoxSelect
3434
from qualtran.bloqs.multiplexers.select_base import SelectOracle
@@ -45,7 +45,7 @@ def _total_bits(registers: Union[Tuple[Register, ...], Signature]) -> int:
4545

4646

4747
@attrs.frozen
48-
class SelectBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtension):
48+
class SelectBlockEncoding(BlockEncoding):
4949
r"""LCU based block encoding using SELECT and PREPARE oracles.
5050
5151
Builds the block encoding via
@@ -96,11 +96,6 @@ class SelectBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtensi
9696

9797
select: Union[BlackBoxSelect, SelectOracle]
9898
prepare: Union[BlackBoxPrepare, PrepareOracle]
99-
control_val: Optional[int] = None
100-
101-
@cached_property
102-
def control_registers(self) -> Tuple[Register, ...]:
103-
return self.select.control_registers
10499

105100
@cached_property
106101
def ancilla_bitsize(self) -> int:
@@ -137,14 +132,7 @@ def epsilon(self) -> SymbolicFloat:
137132

138133
@cached_property
139134
def signature(self) -> Signature:
140-
return Signature(
141-
[
142-
*self.control_registers,
143-
*self.selection_registers,
144-
*self.junk_registers,
145-
*self.target_registers,
146-
]
147-
)
135+
return Signature([*self.selection_registers, *self.junk_registers, *self.target_registers])
148136

149137
@cached_property
150138
def signal_state(self) -> Union[BlackBoxPrepare, PrepareOracle]:
@@ -158,26 +146,12 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: SoquetT) -> Dict[str,
158146
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
159147
if reg is None:
160148
return Text('')
161-
if reg.name == 'control':
162-
return Circle(filled=bool(self.control_val))
163149
else:
164150
return TextBox('B[H]')
165151

166-
def get_single_qubit_controlled_bloq(self, control_val: int) -> 'SelectBlockEncoding':
167-
if self.control_val is not None:
168-
raise ValueError(
169-
"control_val is not None but trying to build controlled SelectBlockEncoding."
170-
)
171-
c_select = self.select.controlled(ctrl_spec=CtrlSpec(cvs=control_val))
172-
if not isinstance(c_select, SelectOracle):
173-
raise TypeError(
174-
f"controlled version of {self.select} = {c_select} must also be a SelectOracle"
175-
)
176-
return attrs.evolve(self, select=c_select, control_val=control_val)
177-
178152

179153
@attrs.frozen
180-
class LCUBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtension):
154+
class LCUBlockEncoding(BlockEncoding):
181155
r"""LCU based block encoding using SELECT and PREPARE oracles.
182156
183157
Builds the standard block encoding from an LCU as
@@ -231,7 +205,7 @@ class LCUBlockEncoding(BlockEncoding, SpecializedSingleQubitControlledExtension)
231205

232206
@cached_property
233207
def control_registers(self) -> Tuple[Register, ...]:
234-
return self.select.control_registers
208+
return () if self.control_val is None else (Register('ctrl', QBit()),)
235209

236210
@cached_property
237211
def ancilla_bitsize(self) -> int:
@@ -287,8 +261,18 @@ def _extract_soqs(bloq: Bloq) -> Dict[str, 'SoquetT']:
287261
return {reg.name: soqs.pop(reg.name) for reg in bloq.signature.lefts()}
288262

289263
soqs |= bb.add_d(self.prepare, **_extract_soqs(self.prepare))
290-
soqs |= bb.add_d(self.select, **_extract_soqs(self.select))
264+
265+
select_soqs = _extract_soqs(self.select)
266+
if self.control_val is None:
267+
soqs |= bb.add_d(self.select, **select_soqs)
268+
else:
269+
_, ctrl_select_adder = self.select.get_ctrl_system(CtrlSpec(cvs=self.control_val))
270+
(ctrl,), select_soqs_t = ctrl_select_adder(bb, [soqs.pop('ctrl')], select_soqs)
271+
soqs |= {'ctrl': ctrl}
272+
soqs |= dict(zip([reg.name for reg in self.select.signature], select_soqs_t))
273+
291274
soqs |= bb.add_d(self.prepare.adjoint(), **_extract_soqs(self.prepare.adjoint()))
275+
292276
return soqs
293277

294278
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
@@ -299,17 +283,13 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
299283
else:
300284
return TextBox('B[H]')
301285

302-
def get_single_qubit_controlled_bloq(self, control_val: int) -> 'LCUBlockEncoding':
303-
if self.control_val is not None:
304-
raise ValueError(
305-
"control_val is not None but trying to build controlled SelectBlockEncoding."
306-
)
307-
c_select = self.select.controlled(ctrl_spec=CtrlSpec(cvs=control_val))
308-
if not isinstance(c_select, SelectOracle):
309-
raise TypeError(
310-
f"controlled version of {self.select} = {c_select} must also be a SelectOracle"
311-
)
312-
return attrs.evolve(self, select=c_select, control_val=control_val)
286+
def adjoint(self) -> 'Bloq':
287+
from qualtran.bloqs.mcmt.specialized_ctrl import (
288+
AdjointWithSpecializedCtrl,
289+
SpecializeOnCtrlBit,
290+
)
291+
292+
return AdjointWithSpecializedCtrl(self, SpecializeOnCtrlBit.ONE)
313293

314294

315295
@bloq_example

qualtran/bloqs/mean_estimation/mean_estimation_operator.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,20 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Iterator, Optional, Tuple
16+
from typing import Iterator, Tuple, TYPE_CHECKING
1717

1818
import attrs
19-
import cirq
2019
from numpy.typing import NDArray
2120

22-
from qualtran import CtrlSpec, Register, Signature
23-
from qualtran._infra.gate_with_registers import GateWithRegisters, total_bits
24-
from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension
21+
from qualtran import GateWithRegisters, Register, Signature
2522
from qualtran.bloqs.mean_estimation.complex_phase_oracle import ComplexPhaseOracle
2623
from qualtran.bloqs.multiplexers.select_base import SelectOracle
2724
from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare
2825
from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle
2926

27+
if TYPE_CHECKING:
28+
import cirq
29+
3030

3131
@attrs.frozen
3232
class CodeForRandomVariable:
@@ -65,7 +65,7 @@ def __attrs_post_init__(self):
6565

6666

6767
@attrs.frozen
68-
class MeanEstimationOperator(GateWithRegisters, SpecializedSingleQubitControlledExtension): # type: ignore[misc]
68+
class MeanEstimationOperator(GateWithRegisters):
6969
r"""Mean estimation operator $U=REFL_{p} ROT_{y}$ as per Sec 3.1 of arxiv.org:2208.07544.
7070
7171
The MeanEstimationOperator (aka KO Operator) expects `CodeForRandomVariable` to specify the
@@ -84,51 +84,39 @@ class MeanEstimationOperator(GateWithRegisters, SpecializedSingleQubitControlled
8484
"""
8585

8686
code: CodeForRandomVariable
87-
control_val: Optional[int] = None
8887
arctan_bitsize: int = 32
8988

9089
@cached_property
9190
def reflect(self) -> ReflectionUsingPrepare:
92-
return ReflectionUsingPrepare(
93-
self.code.synthesizer, global_phase=-1, control_val=self.control_val
94-
)
91+
return ReflectionUsingPrepare(self.code.synthesizer, global_phase=-1)
9592

9693
@cached_property
9794
def select(self) -> ComplexPhaseOracle:
9895
return ComplexPhaseOracle(self.code.encoder, self.arctan_bitsize)
9996

100-
@cached_property
101-
def control_registers(self) -> Tuple[Register, ...]:
102-
return self.code.encoder.control_registers
103-
10497
@cached_property
10598
def selection_registers(self) -> Tuple[Register, ...]:
10699
return self.code.encoder.selection_registers
107100

108101
@cached_property
109102
def signature(self) -> Signature:
110-
return Signature([*self.control_registers, *self.selection_registers])
103+
return Signature([*self.selection_registers])
111104

112105
def decompose_from_registers(
113106
self,
114107
*,
115-
context: cirq.DecompositionContext,
116-
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
117-
) -> Iterator[cirq.OP_TREE]:
108+
context: 'cirq.DecompositionContext',
109+
**quregs: NDArray['cirq.Qid'], # type:ignore[type-var]
110+
) -> Iterator['cirq.OP_TREE']:
118111
select_reg = {reg.name: quregs[reg.name] for reg in self.select.signature}
119112
reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.signature}
120113
yield self.select.on_registers(**select_reg)
121114
yield self.reflect.on_registers(**reflect_reg)
122115

123-
def get_single_qubit_controlled_bloq(self, control_val: int) -> 'MeanEstimationOperator':
124-
c_encoder = self.code.encoder.controlled(ctrl_spec=CtrlSpec(cvs=control_val))
125-
assert isinstance(c_encoder, SelectOracle)
126-
c_code = attrs.evolve(self.code, encoder=c_encoder)
127-
return attrs.evolve(self, code=c_code, control_val=control_val)
128-
129-
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
130-
wire_symbols = []
131-
if self.control_val is not None:
132-
wire_symbols.append("@" if self.control_val == 1 else "(0)")
133-
wire_symbols += ['U_ko'] * (total_bits(self.signature) - total_bits(self.control_registers))
116+
def _circuit_diagram_info_(
117+
self, args: 'cirq.CircuitDiagramInfoArgs'
118+
) -> 'cirq.CircuitDiagramInfo':
119+
import cirq
120+
121+
wire_symbols = ['U_ko'] * self.signature.n_qubits()
134122
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
import pytest
2121
from attrs import frozen
2222

23-
from qualtran import BQUInt, QAny, QBit, QUInt, Register
23+
from qualtran import BQUInt, QAny, QUInt, Register
2424
from qualtran._infra.gate_with_registers import get_named_qubits, total_bits
25-
from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension
2625
from qualtran.bloqs.mean_estimation.mean_estimation_operator import (
2726
CodeForRandomVariable,
2827
MeanEstimationOperator,
@@ -52,7 +51,7 @@ def decompose_from_registers( # type:ignore[override]
5251

5352

5453
@frozen
55-
class BernoulliEncoder(SelectOracle, SpecializedSingleQubitControlledExtension): # type: ignore[misc]
54+
class BernoulliEncoder(SelectOracle):
5655
r"""Encodes Bernoulli random variable y0/y1 as $Enc|ii..i>|0> = |ii..i>|y_{i}>$ where i=0/1."""
5756

5857
p: float
@@ -63,7 +62,7 @@ class BernoulliEncoder(SelectOracle, SpecializedSingleQubitControlledExtension):
6362

6463
@cached_property
6564
def control_registers(self) -> Tuple[Register, ...]:
66-
return () if self.control_val is None else (Register('control', QBit()),)
65+
return ()
6766

6867
@cached_property
6968
def selection_registers(self) -> Tuple[Register, ...]:

0 commit comments

Comments
 (0)