diff --git a/qualtran/bloqs/basic_gates/swap.py b/qualtran/bloqs/basic_gates/swap.py index dc8d27e24..e251333a3 100644 --- a/qualtran/bloqs/basic_gates/swap.py +++ b/qualtran/bloqs/basic_gates/swap.py @@ -105,19 +105,15 @@ def adjoint(self) -> 'Bloq': return self def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']: - if ctrl_spec != CtrlSpec(): - return super().get_ctrl_system(ctrl_spec=ctrl_spec) - - cswap = TwoBitCSwap() - - def adder( - bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: Dict[str, 'SoquetT'] - ) -> Tuple[Iterable['SoquetT'], Iterable['SoquetT']]: - (ctrl,) = ctrl_soqs - ctrl, x, y = bb.add(cswap, ctrl=ctrl, x=in_soqs['x'], y=in_soqs['y']) - return [ctrl], [x, y] + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv_from_bloqs - return cswap, adder + return get_ctrl_system_1bit_cv_from_bloqs( + self, + ctrl_spec, + current_ctrl_bit=None, + bloq_with_ctrl=TwoBitCSwap(), + ctrl_reg_name='ctrl', + ) @bloq_example @@ -201,6 +197,13 @@ def wire_symbol(self, reg: Optional['Register'], idx: Tuple[int, ...] = ()) -> ' else: return TextBox('×') + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']: + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv_from_bloqs + + return get_ctrl_system_1bit_cv_from_bloqs( + self, ctrl_spec, current_ctrl_bit=1, bloq_with_ctrl=self, ctrl_reg_name='ctrl' + ) + @bloq_example def _cswap_bit() -> TwoBitCSwap: diff --git a/qualtran/bloqs/basic_gates/y_gate.py b/qualtran/bloqs/basic_gates/y_gate.py index b418fc406..5212bb5ca 100644 --- a/qualtran/bloqs/basic_gates/y_gate.py +++ b/qualtran/bloqs/basic_gates/y_gate.py @@ -173,6 +173,13 @@ def wire_symbol( return TextBox('Y') raise ValueError(f"Unknown register {reg}.") + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv_from_bloqs + + return get_ctrl_system_1bit_cv_from_bloqs( + self, ctrl_spec, current_ctrl_bit=1, bloq_with_ctrl=self, ctrl_reg_name='ctrl' + ) + @bloq_example def _cy_gate() -> CYGate: diff --git a/qualtran/bloqs/basic_gates/z_basis.py b/qualtran/bloqs/basic_gates/z_basis.py index 40bd02255..599acfc89 100644 --- a/qualtran/bloqs/basic_gates/z_basis.py +++ b/qualtran/bloqs/basic_gates/z_basis.py @@ -340,6 +340,13 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) - return Circle() raise ValueError(f'Unknown wire symbol register name: {reg.name}') + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv_from_bloqs + + return get_ctrl_system_1bit_cv_from_bloqs( + self, ctrl_spec, current_ctrl_bit=1, bloq_with_ctrl=self, ctrl_reg_name='q1' + ) + @bloq_example def _cz() -> CZ: diff --git a/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py b/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py index f99a56971..66a123fa5 100644 --- a/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py +++ b/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py @@ -90,6 +90,10 @@ class SelectHubbard(SelectOracle): def __attrs_post_init__(self): if self.x_dim != self.y_dim: raise NotImplementedError("Currently only supports the case where x_dim=y_dim.") + if self.control_val == 0: + raise NotImplementedError( + "control_val=0 not supported, use `SelectHubbard(x, y).controlled(CtrlSpec(cvs=0))` instead" + ) @cached_property def control_registers(self) -> Tuple[Register, ...]: @@ -191,18 +195,24 @@ def __str__(self): return s def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: - from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv_from_bloqs - return get_ctrl_system_1bit_cv( + return get_ctrl_system_1bit_cv_from_bloqs( self, ctrl_spec=ctrl_spec, current_ctrl_bit=self.control_val, - get_ctrl_bloq_and_ctrl_reg_name=lambda cv: ( - attrs.evolve(self, control_val=cv), - 'control', - ), + bloq_with_ctrl=attrs.evolve(self, control_val=1), + ctrl_reg_name='control', + ) + + def adjoint(self) -> 'Bloq': + from qualtran.bloqs.mcmt.specialized_ctrl import ( + AdjointWithSpecializedCtrl, + SpecializeOnCtrlBit, ) + return AdjointWithSpecializedCtrl(self, specialize_on_ctrl=SpecializeOnCtrlBit.ONE) + @bloq_example def _sel_hubb() -> SelectHubbard: diff --git a/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard_test.py b/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard_test.py index da6ce3deb..da964a804 100644 --- a/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard_test.py +++ b/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import ANY import pytest @@ -18,7 +19,7 @@ _sel_hubb, SelectHubbard, ) -from qualtran.cirq_interop.t_complexity_protocol import t_complexity +from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost def test_sel_hubb_auto(bloq_autotester): @@ -28,8 +29,19 @@ def test_sel_hubb_auto(bloq_autotester): @pytest.mark.parametrize('dim', [*range(2, 10)]) def test_select_t_complexity(dim): select = SelectHubbard(x_dim=dim, y_dim=dim, control_val=1) - cost = t_complexity(select) + cost = get_cost_value(select, QECGatesCost()) N = 2 * dim * dim logN = 2 * (dim - 1).bit_length() + 1 - assert cost.t == 10 * N + 14 * logN - 8 - assert cost.rotations == 0 + assert cost == GateCounts( + cswap=2 * logN, and_bloq=5 * (N // 2) - 2, measurement=5 * (N // 2) - 2, clifford=ANY + ) + assert cost.total_t_count() == 10 * N + 14 * logN - 8 + + +def test_adjoint_controlled(): + bloq = _sel_hubb() + + adj_ctrl_bloq = bloq.controlled().adjoint() + ctrl_adj_bloq = bloq.adjoint().controlled() + + assert adj_ctrl_bloq == ctrl_adj_bloq diff --git a/qualtran/bloqs/mcmt/specialized_ctrl.ipynb b/qualtran/bloqs/mcmt/specialized_ctrl.ipynb new file mode 100644 index 000000000..67c4664b3 --- /dev/null +++ b/qualtran/bloqs/mcmt/specialized_ctrl.ipynb @@ -0,0 +1,201 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "## Bloqs with specialized controlled implementations\n", + "\n", + "In some cases, a bloq may have a specialized singly-controlled version (e.g. `LCUBlockEncoding`).\n", + "Qualtran provides a convenience methods `get_ctrl_system_1bit_cv` and `get_ctrl_system_1bit_cv_from_bloqs` to override the `get_ctrl_system`. These methods ensure that multiply-controlled bloqs are correctly reduced to the provided singly-controlled variants.\n", + "\n", + "- `get_ctrl_system_1bit_cv_from_bloqs` - Override when a specialized controlled-by-1 implementation is available.\n", + "- `get_ctrl_system_1bit_cv` - Override when both specialized controlled-by-1 and controlled-by-0 implementations are available.\n", + "\n", + "The following demonstrates an example for a bloq implementing $T^\\dagger X T$, where the controlled version only needs to control the $X$." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import attrs\n", + "from qualtran import Bloq, BloqBuilder, Soquet, SoquetT, Signature, CtrlSpec, AddControlledT\n", + "from qualtran.bloqs.basic_gates import TGate, XGate, CNOT\n", + "\n", + "\n", + "@attrs.frozen\n", + "class BloqWithSpecializedCtrl(Bloq):\n", + " \"\"\"Bloq implementing $T^\\dagger X T$\"\"\"\n", + " is_controlled: bool = False\n", + "\n", + " @property\n", + " def signature(self) -> 'Signature':\n", + " n_ctrls = 1 if self.is_controlled else 0\n", + " return Signature.build(ctrl=n_ctrls, q=1)\n", + " \n", + " def build_composite_bloq(self, bb: 'BloqBuilder', q: 'Soquet', **soqs) -> dict[str, 'SoquetT']:\n", + " ctrl = soqs.pop('ctrl', None)\n", + " \n", + " q = bb.add(TGate(), q=q)\n", + " if self.is_controlled:\n", + " ctrl, q = bb.add(CNOT(), ctrl=ctrl, target=q)\n", + " else:\n", + " ctrl, q = bb.add(XGate(), ctrl=ctrl, target=q)\n", + " q = bb.add(TGate().adjoint(), q=q)\n", + " \n", + " out_soqs = {'q': q}\n", + " if ctrl:\n", + " out_soqs |= {'ctrl': ctrl}\n", + " return out_soqs\n", + " \n", + " def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']:\n", + " from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv_from_bloqs\n", + "\n", + " return get_ctrl_system_1bit_cv_from_bloqs(\n", + " self,\n", + " ctrl_spec,\n", + " current_ctrl_bit=1 if self.is_controlled else None,\n", + " bloq_with_ctrl=attrs.evolve(self, is_controlled=True),\n", + " ctrl_reg_name='ctrl',\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloq, show_call_graph\n", + "\n", + "bloq = BloqWithSpecializedCtrl().controlled().controlled()\n", + "show_bloq(bloq.decompose_bloq().flatten())\n", + "show_call_graph(bloq)" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Propagating the Adjoint\n", + "\n", + "In the above bloq, calling controlled on the adjoint does not push the controls into the bloq, and therefore does not use the specialized implementation provided." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "BloqWithSpecializedCtrl().adjoint().controlled()" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "This can be fixed by overriding the adjoint using a special wrapper for this case - `AdjointWithSpecializedCtrl`. This is a subclass of the default `Adjoint` metabloq, and ensures that single-qubit controls are pushed into the underlying bloq." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "@attrs.frozen\n", + "class BloqWithSpecializedCtrlWithAdjoint(Bloq):\n", + " \"\"\"Bloq implementing $T^\\dagger X T$\"\"\"\n", + " is_controlled: bool = False\n", + "\n", + " @property\n", + " def signature(self) -> 'Signature':\n", + " n_ctrls = 1 if self.is_controlled else 0\n", + " return Signature.build(ctrl=n_ctrls, q=1)\n", + " \n", + " def build_composite_bloq(self, bb: 'BloqBuilder', q: 'Soquet', **soqs) -> dict[str, 'SoquetT']:\n", + " ctrl = soqs.pop('ctrl', None)\n", + " \n", + " q = bb.add(TGate(), q=q)\n", + " if self.is_controlled:\n", + " ctrl, q = bb.add(CNOT(), ctrl=ctrl, target=q)\n", + " else:\n", + " ctrl, q = bb.add(XGate(), ctrl=ctrl, target=q)\n", + " q = bb.add(TGate().adjoint(), q=q)\n", + " \n", + " out_soqs = {'q': q}\n", + " if ctrl:\n", + " out_soqs |= {'ctrl': ctrl}\n", + " return out_soqs\n", + " \n", + " def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']:\n", + " from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv_from_bloqs\n", + "\n", + " return get_ctrl_system_1bit_cv_from_bloqs(\n", + " self,\n", + " ctrl_spec,\n", + " current_ctrl_bit=1 if self.is_controlled else None,\n", + " bloq_with_ctrl=attrs.evolve(self, is_controlled=True),\n", + " ctrl_reg_name='ctrl',\n", + " )\n", + "\n", + " def adjoint(self):\n", + " from qualtran.bloqs.mcmt.specialized_ctrl import AdjointWithSpecializedCtrl, SpecializeOnCtrlBit\n", + " \n", + " return AdjointWithSpecializedCtrl(self, SpecializeOnCtrlBit.ONE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "BloqWithSpecializedCtrlWithAdjoint().adjoint().controlled()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "assert BloqWithSpecializedCtrlWithAdjoint().adjoint().controlled() == BloqWithSpecializedCtrlWithAdjoint(is_controlled=True).adjoint()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/qualtran/bloqs/mcmt/specialized_ctrl.py b/qualtran/bloqs/mcmt/specialized_ctrl.py index 8c70df3f6..ce1bd3362 100644 --- a/qualtran/bloqs/mcmt/specialized_ctrl.py +++ b/qualtran/bloqs/mcmt/specialized_ctrl.py @@ -11,16 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum from functools import cached_property from typing import Callable, cast, Iterable, Optional, Sequence, TYPE_CHECKING import attrs import numpy as np -from qualtran import Bloq, QBit, Register, Signature +from qualtran import Adjoint, Bloq, BloqBuilder, CompositeBloq, QBit, Register, Signature +from qualtran.bloqs.bookkeeping import AutoPartition if TYPE_CHECKING: - from qualtran import AddControlledT, BloqBuilder, CtrlSpec, SoquetT + from qualtran import AddControlledT, CtrlSpec, SoquetT from qualtran._infra.controlled import ControlBit from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator @@ -180,7 +182,12 @@ def _adder( return [ctrl0], [ctrl1, *out_soqs] - return ctrl_bloq, _adder + def _unwrap(b): + if isinstance(b, AutoPartition): + return _unwrap(b.bloq) + return b + + return _unwrap(ctrl_bloq), _adder def get_ctrl_system_1bit_cv( @@ -253,3 +260,120 @@ def get_ctrl_bloq_and_ctrl_reg_name(cv: 'ControlBit') -> Optional[tuple['Bloq', current_ctrl_bit=current_ctrl_bit, get_ctrl_bloq_and_ctrl_reg_name=get_ctrl_bloq_and_ctrl_reg_name, ) + + +class SpecializeOnCtrlBit(enum.Flag): + """Control-specs to propagate to the subbloq. + + See `AdjointWithSpecializedCtrl` for usage. + + Currently only allows pushing a single-qubit-control. + """ + + NONE = enum.auto() + ZERO = enum.auto() + ONE = enum.auto() + BOTH = ZERO | ONE + + +@attrs.frozen() +class AdjointWithSpecializedCtrl(Adjoint): + """Adjoint of a bloq with a specialized control implementation. + + If the subbloq has a specialized control implementation, then calling + `Adjoint(subbloq).controlled()` propagates the controls to the subbloq. + This only propagates single-qubit `CtrlSpec`s, all others use the default: + reduced to single-qubit control using the `ControlledViaAnd` bloq. + + By default in Qualtran, `Controlled(bloq).adjoint()` returns `Controlled(bloq.adjoint())`. + But `Adjoint(bloq).controlled()` does not propagate the controls, therefore returns + `Controlled(Adjoint(bloq))`. + This bloq helps override that behaviour for single-qubit controlled versions. + + For example, if a bloq has a specialized implementation for the controlled-by-1 case: + + ```py + class BloqWithSpecializedCtrl(Bloq): + ... + + def adjoint(self): + return AdjointWithSpecializedCtrl(self, SpecializeOnCtrlBit.ONE) + ``` + + See `get_ctrl_system_1bit_cv` on one way to provide specialized controlled implementations + for bloqs. If a bloq uses the above and does not have a trivial `adjoint` implementation, + it is recommended to override the `adjoint` method as show above. + + Caution: + Use this bloq _only_ when a specialized control implementation is guaranteed, + i.e. `subbloq.controlled()` should not return `Controlled(...)`. + Otherwise, it could lead to an infinite recursion. + + Args: + subbloq: The bloq to wrap. + specialize_on_ctrl: Values of the control bit to propagate the control into the subbloq. + Can be `SpecializeOnCtrlBit.ONE` for `1` only, `SpecializeOnCtrlBit.ZERO` for `0` only, + or `SpecializeOnCtrlBit.BOTH` for both `0` and `1`. + """ + + specialize_on_ctrl: SpecializeOnCtrlBit = SpecializeOnCtrlBit.NONE + + def _specialize_control(self, ctrl_spec: 'CtrlSpec') -> bool: + """if True, push the control to the subbloq""" + if ctrl_spec.num_qubits != 1: + return False + + cv = ctrl_spec.get_single_ctrl_bit() + cv_flag = SpecializeOnCtrlBit.ONE if cv == 1 else SpecializeOnCtrlBit.ZERO + return cv_flag in self.specialize_on_ctrl + + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']: + from qualtran._infra.controlled import _get_nice_ctrl_reg_names + + if not self._specialize_control(ctrl_spec): + # no specialized controlled version available, fallback to default + return super().get_ctrl_system(ctrl_spec) + + # get the builder for the controlled version of subbloq + ctrl_subbloq, ctrl_subbloq_adder = self.subbloq.get_ctrl_system(ctrl_spec) + ctrl_bloq = attrs.evolve(self, subbloq=ctrl_subbloq) + (ctrl_reg_name,) = _get_nice_ctrl_reg_names([reg.name for reg in self.subbloq.signature], 1) + + # build a composite bloq using the control-adder + def _get_adj_cbloq() -> 'CompositeBloq': + bb, initial_soqs = BloqBuilder.from_signature( + self.subbloq.signature, add_registers_allowed=True + ) + ctrl = bb.add_register(ctrl_reg_name, 1) + bb.add_register_allowed = False + + (ctrl,), out_soqs_t = ctrl_subbloq_adder(bb, [ctrl], initial_soqs) + + out_soqs = dict(zip([reg.name for reg in self.subbloq.signature.rights()], out_soqs_t)) + out_soqs |= {ctrl_reg_name: ctrl} + + cbloq = bb.finalize(**out_soqs) + return cbloq.adjoint() + + adj_cbloq = _get_adj_cbloq() + + def _adder( + bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: dict[str, 'SoquetT'] + ) -> tuple[Iterable['SoquetT'], Iterable['SoquetT']]: + (ctrl,) = ctrl_soqs + in_soqs |= {ctrl_reg_name: ctrl} + soqs = bb.add_from(adj_cbloq, **in_soqs) + + # locate the correct control soquet + soqs = list(soqs) + ctrl_soq = None + for soq, reg in zip(soqs, adj_cbloq.signature.rights()): + if reg.name == ctrl_reg_name: + ctrl_soq = soq + soqs.remove(soq) + break + assert ctrl_soq is not None, "ctrl_soq must be present in output soqs" + + return [ctrl_soq], soqs + + return ctrl_bloq, _adder diff --git a/qualtran/bloqs/mcmt/specialized_ctrl_test.py b/qualtran/bloqs/mcmt/specialized_ctrl_test.py index 01282c1f2..c42e95da7 100644 --- a/qualtran/bloqs/mcmt/specialized_ctrl_test.py +++ b/qualtran/bloqs/mcmt/specialized_ctrl_test.py @@ -18,6 +18,7 @@ import attrs import pytest +import qualtran.testing as qlt_testing from qualtran import ( AddControlledT, Bloq, @@ -31,12 +32,19 @@ ) from qualtran.bloqs.mcmt import And from qualtran.bloqs.mcmt.specialized_ctrl import ( + AdjointWithSpecializedCtrl, get_ctrl_system_1bit_cv, get_ctrl_system_1bit_cv_from_bloqs, + SpecializeOnCtrlBit, ) from qualtran.resource_counting import CostKey, GateCounts, get_cost_value, QECGatesCost +def _keep_and(b): + # TODO remove this after https://github.com/quantumlib/Qualtran/issues/1346 is resolved. + return isinstance(b, And) + + @attrs.frozen class AtomWithSpecializedControl(Bloq): cv: Optional[int] = None @@ -78,6 +86,9 @@ def my_static_costs(self, cost_key: 'CostKey'): return NotImplemented + def adjoint(self) -> 'AdjointWithSpecializedCtrl': + return AdjointWithSpecializedCtrl(self, specialize_on_ctrl=SpecializeOnCtrlBit.BOTH) + def ON(n: int = 1) -> CtrlSpec: return CtrlSpec(cvs=[1] * n) @@ -133,6 +144,9 @@ def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlled ctrl_reg_name='ctrl', ) + def adjoint(self) -> 'AdjointWithSpecializedCtrl': + return AdjointWithSpecializedCtrl(self, specialize_on_ctrl=SpecializeOnCtrlBit.ONE) + @attrs.frozen class CTestAtom(Bloq): @@ -147,14 +161,13 @@ def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlled self, ctrl_spec, current_ctrl_bit=1, bloq_with_ctrl=self, ctrl_reg_name='ctrl' ) + def adjoint(self) -> 'AdjointWithSpecializedCtrl': + return AdjointWithSpecializedCtrl(self, specialize_on_ctrl=SpecializeOnCtrlBit.ONE) + def test_bloq_with_controlled_bloq(): assert TestAtom('g').controlled() == CTestAtom('g') - def _keep_and(b): - # TODO remove this after https://github.com/quantumlib/Qualtran/issues/1346 is resolved. - return isinstance(b, And) - ctrl_bloq = CTestAtom('g').controlled() _, sigma = ctrl_bloq.call_graph(keep=_keep_and) assert sigma == {And(): 1, CTestAtom('g'): 1, And().adjoint(): 1} @@ -168,6 +181,27 @@ def _keep_and(b): assert sigma == {And(0, 0): 1, CTestAtom('nn'): 1, And(0, 0).adjoint(): 1} +def test_ctrl_adjoint(): + assert TestAtom('a').adjoint().controlled() == CTestAtom('a').adjoint() + + _, sigma = ( + TestAtom('g') + .adjoint() + .controlled(ctrl_spec=CtrlSpec(cvs=[1, 1])) + .call_graph(keep=_keep_and) + ) + assert sigma == {And(): 1, And().adjoint(): 1, CTestAtom('g').adjoint(): 1} + + _, sigma = CTestAtom('c').adjoint().controlled().call_graph(keep=_keep_and) + assert sigma == {And(): 1, And().adjoint(): 1, CTestAtom('c').adjoint(): 1} + + for cv in [0, 1]: + assert ( + AtomWithSpecializedControl().adjoint().controlled(ctrl_spec=CtrlSpec(cvs=cv)) + == AtomWithSpecializedControl(cv=cv).adjoint() + ) + + @attrs.frozen class TestBloqWithDecompose(Bloq): ctrl_reg_name: str @@ -201,3 +235,8 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> dict[str def test_get_ctrl_system(ctrl_reg_name: str, target_reg_name: str): bloq = TestBloqWithDecompose(ctrl_reg_name, target_reg_name).controlled() _ = bloq.decompose_bloq().flatten() + + +@pytest.mark.notebook +def test_notebook(): + qlt_testing.execute_notebook('specialized_ctrl') diff --git a/qualtran/bloqs/mod_arithmetic/mod_division.py b/qualtran/bloqs/mod_arithmetic/mod_division.py index 575a1aada..005323b96 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_division.py +++ b/qualtran/bloqs/mod_arithmetic/mod_division.py @@ -17,7 +17,7 @@ import numpy as np import sympy -from attrs import evolve, frozen +from attrs import evolve, field, frozen from qualtran import ( Bloq, @@ -65,16 +65,23 @@ def signature(self) -> 'Signature': Register('v', QMontgomeryUInt(self.bitsize)), Register('m', QBit()), Register('f', QBit()), + Register('is_terminal', QBit()), ] ) - def on_classical_vals(self, v: int, m: int, f: int) -> Dict[str, 'ClassicalValT']: + def on_classical_vals( + self, v: int, m: int, f: int, is_terminal: int + ) -> Dict[str, 'ClassicalValT']: + print('here') + assert False m ^= f & (v == 0) + assert is_terminal == 0 + is_terminal ^= m f ^= m - return {'v': v, 'm': m, 'f': f} + return {'v': v, 'm': m, 'f': f, 'is_terminal': is_terminal} def build_composite_bloq( - self, bb: 'BloqBuilder', v: Soquet, m: Soquet, f: Soquet + self, bb: 'BloqBuilder', v: Soquet, m: Soquet, f: Soquet, is_terminal: Soquet ) -> Dict[str, 'SoquetT']: if is_symbolic(self.bitsize): raise DecomposeTypeError(f'symbolic decomposition is not supported for {self}') @@ -89,7 +96,8 @@ def build_composite_bloq( f = ctrls[-1] v = bb.join(v_arr) m, f = bb.add(CNOT(), ctrl=m, target=f) - return {'v': v, 'm': m, 'f': f} + m, is_terminal = bb.add(CNOT(), ctrl=m, target=is_terminal) + return {'v': v, 'm': m, 'f': f, 'is_terminal': is_terminal} def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': if is_symbolic(self.bitsize): @@ -408,16 +416,27 @@ def signature(self) -> 'Signature': Register('s', QMontgomeryUInt(self.bitsize)), Register('m', QBit()), Register('f', QBit()), + Register('is_terminal', QBit()), ] ) def build_composite_bloq( - self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, m: Soquet, f: Soquet + self, + bb: 'BloqBuilder', + u: Soquet, + v: Soquet, + r: Soquet, + s: Soquet, + m: Soquet, + f: Soquet, + is_terminal: Soquet, ) -> Dict[str, 'SoquetT']: a = bb.allocate(1) b = bb.allocate(1) - v, m, f = bb.add(_KaliskiIterationStep1(self.bitsize), v=v, m=m, f=f) + v, m, f, is_terminal = bb.add( + _KaliskiIterationStep1(self.bitsize), v=v, m=m, f=f, is_terminal=is_terminal + ) u, v, b, a, m, f = bb.add( _KaliskiIterationStep2(self.bitsize), u=u, v=v, b=b, a=a, m=m, f=f ) @@ -434,7 +453,7 @@ def build_composite_bloq( bb.free(a) bb.free(b) - return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f} + return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f, 'is_terminal': is_terminal} def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': return { @@ -447,7 +466,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': } def on_classical_vals( - self, u: int, v: int, r: int, s: int, m: int, f: int + self, u: int, v: int, r: int, s: int, m: int, f: int, is_terminal: int ) -> Dict[str, 'ClassicalValT']: """This is the Kaliski algorithm as described in Fig7 of https://arxiv.org/pdf/2001.09580. @@ -456,6 +475,7 @@ def on_classical_vals( of `f` and `m`. """ assert m == 0 + is_terminal = f == 1 and v == 0 if f == 0: # When `f = 0` this means that the algorithm is nearly over and that we just need to # double the value of `r`. @@ -484,7 +504,7 @@ def on_classical_vals( if swap: u, v = v, u r, s = s, r - return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f} + return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f, 'is_terminal': is_terminal} @frozen @@ -504,6 +524,7 @@ def signature(self) -> 'Signature': Register('s', QMontgomeryUInt(self.bitsize)), Register('m', QAny(2 * self.bitsize)), Register('f', QBit()), + Register('terminal_condition', QAny(2 * self.bitsize)), ] ) @@ -512,17 +533,33 @@ def _kaliski_iteration(self): return _KaliskiIteration(self.bitsize, self.mod) def build_composite_bloq( - self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, m: Soquet, f: Soquet + self, + bb: 'BloqBuilder', + u: Soquet, + v: Soquet, + r: Soquet, + s: Soquet, + m: Soquet, + f: Soquet, + terminal_condition: Soquet, ) -> Dict[str, 'SoquetT']: f = bb.add(XGate(), q=f) u = bb.add(XorK(QMontgomeryUInt(self.bitsize), self.mod), x=u) s = bb.add(XorK(QMontgomeryUInt(self.bitsize), 1), x=s) m_arr = bb.split(m) + terminal_condition_arr = bb.split(terminal_condition) for i in range(2 * self.bitsize): - u, v, r, s, m_arr[i], f = bb.add( - self._kaliski_iteration, u=u, v=v, r=r, s=s, m=m_arr[i], f=f + u, v, r, s, m_arr[i], f, terminal_condition_arr[i] = bb.add( + self._kaliski_iteration, + u=u, + v=v, + r=r, + s=s, + m=m_arr[i], + f=f, + is_terminal=terminal_condition_arr[i], ) r = bb.add(BitwiseNot(QMontgomeryUInt(self.bitsize)), x=r) @@ -531,8 +568,43 @@ def build_composite_bloq( u = bb.add(XorK(QMontgomeryUInt(self.bitsize), 1), x=u) s = bb.add(XorK(QMontgomeryUInt(self.bitsize), self.mod), x=s) + # This is an extra step not present in the original Kaliski algorithm in order to + # handle the case of x=0. The invariant of the Kaliski algorithm is that that end of the + # algorithm u=1, s=0, r=mod inverse. This happens for all cases where the modular inverse + # exists (i.e. gcd(x, mod) = 1). + # The case where the input is zero is important. Although mathematically the inverse + # doesn't exist. For the bloq to be unitary it needs to map zero to itself. + # When the input is zero, the terminal values of the registers are r=mod, u=v=mod^1=mod-1 + # (assuming odd modulus). + # So we clean those registers conditioned on the first terminal qubit which is set + # if and only if the input is zero. + terminal_condition_arr[0], r = bb.add( + XorK(QMontgomeryUInt(self.bitsize), self.mod).controlled(), + ctrl=terminal_condition_arr[0], + x=r, + ) + terminal_condition_arr[0], u = bb.add( + XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(), + ctrl=terminal_condition_arr[0], + x=u, + ) + terminal_condition_arr[0], s = bb.add( + XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(), + ctrl=terminal_condition_arr[0], + x=s, + ) + m = bb.join(m_arr) - return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f} + terminal_condition = bb.join(terminal_condition_arr) + return { + 'u': u, + 'v': v, + 'r': r, + 's': s, + 'm': m, + 'f': f, + 'terminal_condition': terminal_condition, + } def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': return { @@ -542,6 +614,8 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': XGate(): 1, XorK(QMontgomeryUInt(self.bitsize), self.mod): 2, XorK(QMontgomeryUInt(self.bitsize), 1): 2, + XorK(QMontgomeryUInt(self.bitsize), self.mod).controlled(): 1, + XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(): 2, } @@ -575,7 +649,7 @@ class KaliskiModInverse(Bloq): """ bitsize: 'SymbolicInt' - mod: 'SymbolicInt' + mod: 'SymbolicInt' = field(validator=lambda _, __, v: is_symbolic(v) or v % 2 == 1) uncompute: bool = False @cached_property @@ -584,12 +658,12 @@ def signature(self) -> 'Signature': return Signature( [ Register('x', QMontgomeryUInt(self.bitsize)), - Register('m', QAny(2 * self.bitsize), side=side), + Register('junk', QAny(4 * self.bitsize), side=side), ] ) def build_composite_bloq( - self, bb: 'BloqBuilder', x: Soquet, m: Optional[Soquet] = None, f: Optional[Soquet] = None + self, bb: 'BloqBuilder', x: Soquet, junk: Optional[Soquet] = None ) -> Dict[str, 'SoquetT']: u = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) r = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) @@ -597,9 +671,12 @@ def build_composite_bloq( f = bb.allocate(1) if self.uncompute: - assert m is not None - u, x, r, s, m, f = cast( - Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet], + assert junk is not None + junk_arr = bb.split(junk) + m = bb.join(junk_arr[: 2 * self.bitsize]) + terminal_condition = bb.join(junk_arr[2 * self.bitsize :]) + u, x, r, s, m, f, terminal_condition = cast( + Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet, Soquet], bb.add_from( _KaliskiModInverseImpl(self.bitsize, self.mod).adjoint(), u=u, @@ -608,6 +685,7 @@ def build_composite_bloq( s=s, m=m, f=f, + terminal_condition=terminal_condition, ), ) bb.free(u) @@ -615,22 +693,31 @@ def build_composite_bloq( bb.free(s) bb.free(m) bb.free(f) + bb.free(terminal_condition) return {'x': x} m = bb.allocate(2 * self.bitsize) - u, v, x, s, m, f = bb.add_from( - _KaliskiModInverseImpl(self.bitsize, self.mod), u=u, v=x, r=r, s=s, m=m, f=f + terminal_condition = bb.allocate(2 * self.bitsize) + u, v, x, s, m, f, terminal_condition = cast( + Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet, Soquet], + bb.add_from( + _KaliskiModInverseImpl(self.bitsize, self.mod), + u=u, + v=x, + r=r, + s=s, + m=m, + f=f, + terminal_condition=terminal_condition, + ), ) - assert isinstance(u, Soquet) - assert isinstance(v, Soquet) - assert isinstance(s, Soquet) - assert isinstance(f, Soquet) bb.free(u) bb.free(v) bb.free(s) bb.free(f) - return {'x': x, 'm': m} + junk = bb.join(np.concatenate([bb.split(m), bb.split(terminal_condition)])) + return {'x': x, 'junk': junk} def adjoint(self) -> 'KaliskiModInverse': return evolve(self, uncompute=not self.uncompute) @@ -638,17 +725,25 @@ def adjoint(self) -> 'KaliskiModInverse': def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': return _KaliskiModInverseImpl(self.bitsize, self.mod).build_call_graph(ssa) - def on_classical_vals(self, x: int, m: int = 0) -> Dict[str, 'ClassicalValT']: - u, v, r, s, f = int(self.mod), x, 0, 1, 1 + def on_classical_vals(self, x: int, junk: int = 0) -> Dict[str, 'ClassicalValT']: + mod = int(self.mod) + u, v, r, s, f = mod, x, 0, 1, 1 + terminal_condition = m = 0 iteration = _KaliskiModInverseImpl(self.bitsize, self.mod)._kaliski_iteration for _ in range(2 * int(self.bitsize)): - u, v, r, s, m_i, f = iteration.call_classically(u=u, v=v, r=r, s=s, m=0, f=f) + u, v, r, s, m_i, f, is_terminal = iteration.call_classically( + u=u, v=v, r=r, s=s, m=0, f=f, is_terminal=0 + ) m = (m << 1) | m_i - assert u == 1 - assert s == self.mod + terminal_condition = (terminal_condition << 1) | is_terminal + assert u == 1 or (x == 0 and u == mod) + assert s == self.mod or (x == 0 and s == 1) assert f == 0 assert v == 0 - return {'x': self.mod - r, 'm': m} + return { + 'x': (self.mod - r) if r else 0, + 'junk': m * 2 ** (2 * self.bitsize) + terminal_condition, + } @bloq_example diff --git a/qualtran/bloqs/mod_arithmetic/mod_division_test.py b/qualtran/bloqs/mod_arithmetic/mod_division_test.py index 02646ef82..093f0908f 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_division_test.py +++ b/qualtran/bloqs/mod_arithmetic/mod_division_test.py @@ -36,11 +36,24 @@ def test_kaliski_mod_inverse_classical_action(bitsize, mod): continue x_montgomery = dtype.uint_to_montgomery(x, mod) res = blq.call_classically(x=x_montgomery) + print(x, x_montgomery) assert res == cblq.call_classically(x=x_montgomery) assert len(res) == 2 assert res[0] == dtype.montgomery_inverse(x_montgomery, mod) assert dtype.montgomery_product(int(res[0]), x_montgomery, mod) == R - assert blq.adjoint().call_classically(x=res[0], m=res[1]) == (x_montgomery,) + assert blq.adjoint().call_classically(x=res[0], junk=res[1]) == (x_montgomery,) + + +@pytest.mark.parametrize('bitsize', [5, 6]) +@pytest.mark.parametrize('mod', [3, 5, 7, 11, 13, 15]) +def test_kaliski_mod_inverse_classical_action_zero(bitsize, mod): + blq = KaliskiModInverse(bitsize, mod) + cblq = blq.decompose_bloq() + # When x = 0 the terminal condition is achieved at the first iteration, this corresponds to + # m_0 = is_terminal_0 = 1 and all other bits = 0. + junk = 2 ** (4 * bitsize - 1) + 2 ** (2 * bitsize - 1) + assert blq.call_classically(x=0) == cblq.call_classically(x=0) == (0, junk) + assert blq.adjoint().call_classically(x=0, junk=junk) == (0,) @pytest.mark.parametrize('bitsize', [5, 6]) diff --git a/qualtran/bloqs/multiplexers/select_pauli_lcu.py b/qualtran/bloqs/multiplexers/select_pauli_lcu.py index f94b890ac..7ca6b0d02 100644 --- a/qualtran/bloqs/multiplexers/select_pauli_lcu.py +++ b/qualtran/bloqs/multiplexers/select_pauli_lcu.py @@ -139,6 +139,9 @@ def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlled ), ) + def adjoint(self) -> 'Bloq': + return self + @bloq_example(generalizer=[cirq_to_bloqs, ignore_split_join, ignore_cliffords]) def _select_pauli_lcu() -> SelectPauliLCU: diff --git a/qualtran/bloqs/phase_estimation/phase_estimation_of_quantum_walk.ipynb b/qualtran/bloqs/phase_estimation/phase_estimation_of_quantum_walk.ipynb index 0268db0a8..6937510fa 100644 --- a/qualtran/bloqs/phase_estimation/phase_estimation_of_quantum_walk.ipynb +++ b/qualtran/bloqs/phase_estimation/phase_estimation_of_quantum_walk.ipynb @@ -92,7 +92,7 @@ " state_prep = cirq.StatePreparationChannel(get_resource_state(m), name='chi_m')\n", "\n", " yield state_prep.on(*m_qubits)\n", - " yield walk_controlled.on_registers(**walk_regs, control=m_qubits[0])\n", + " yield walk_controlled.on_registers(**walk_regs, ctrl=m_qubits[0])\n", " for i in range(1, m):\n", " yield reflect_controlled.on_registers(control=m_qubits[i], **reflect_regs)\n", " walk = walk ** 2\n", diff --git a/qualtran/bloqs/qubitization/qubitization_walk_operator.ipynb b/qualtran/bloqs/qubitization/qubitization_walk_operator.ipynb index 31ff0e00e..a31dafeb4 100644 --- a/qualtran/bloqs/qubitization/qubitization_walk_operator.ipynb +++ b/qualtran/bloqs/qubitization/qubitization_walk_operator.ipynb @@ -143,8 +143,7 @@ "\n", "#### Parameters\n", " - `select`: The SELECT lcu gate implementing $\\mathrm{SELECT}=\\sum_{l}|l\\rangle\\langle l|H_{l}$.\n", - " - `prepare`: Then PREPARE lcu gate implementing $\\mathrm{PREPARE}|0\\dots 0\\rangle = \\sum_l \\sqrt{\\frac{w_{l}}{\\lambda}} |l\\rangle = |L\\rangle$\n", - " - `control_val`: If 0/1, a controlled version of the walk operator is constructed. Defaults to None, in which case the resulting walk operator is not controlled. \n", + " - `prepare`: Then PREPARE lcu gate implementing $\\mathrm{PREPARE}|0\\dots 0\\rangle = \\sum_l \\sqrt{\\frac{w_{l}}{\\lambda}} |l\\rangle = |L\\rangle$ \n", "\n", "#### References\n", " - [Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity](https://arxiv.org/abs/1805.03662). Babbush et al. (2018). Figure 1.\n" diff --git a/qualtran/bloqs/qubitization/qubitization_walk_operator.py b/qualtran/bloqs/qubitization/qubitization_walk_operator.py index 0748a6748..9e270d8ec 100644 --- a/qualtran/bloqs/qubitization/qubitization_walk_operator.py +++ b/qualtran/bloqs/qubitization/qubitization_walk_operator.py @@ -27,16 +27,21 @@ """ from functools import cached_property -from typing import Iterator, Optional, Tuple, Union +from typing import Tuple, Union import attrs import cirq import numpy as np -from numpy.typing import NDArray -from qualtran import bloq_example, BloqDocSpec, CtrlSpec, Register, Signature -from qualtran._infra.gate_with_registers import GateWithRegisters, total_bits -from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension +from qualtran import ( + bloq_example, + BloqBuilder, + BloqDocSpec, + GateWithRegisters, + Register, + Signature, + SoquetT, +) from qualtran.bloqs.block_encoding.lcu_block_encoding import ( BlackBoxPrepare, LCUBlockEncoding, @@ -53,7 +58,7 @@ @attrs.frozen(cache_hash=True) -class QubitizationWalkOperator(GateWithRegisters, SpecializedSingleQubitControlledExtension): # type: ignore[misc] +class QubitizationWalkOperator(GateWithRegisters): r"""Construct a Szegedy Quantum Walk operator using LCU oracles SELECT and PREPARE. For a Hamiltonian $H = \sum_l w_l H_l$ (where coefficients $w_l > 0$ and $H_l$ are unitaries), @@ -79,8 +84,6 @@ class QubitizationWalkOperator(GateWithRegisters, SpecializedSingleQubitControll prepare: Then PREPARE lcu gate implementing $\mathrm{PREPARE}|0\dots 0\rangle = \sum_l \sqrt{\frac{w_{l}}{\lambda}} |l\rangle = |L\rangle$ - control_val: If 0/1, a controlled version of the walk operator is constructed. Defaults to - None, in which case the resulting walk operator is not controlled. References: [Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity](https://arxiv.org/abs/1805.03662). @@ -88,12 +91,6 @@ class QubitizationWalkOperator(GateWithRegisters, SpecializedSingleQubitControll """ block_encoding: Union[SelectBlockEncoding, LCUBlockEncoding] - control_val: Optional[int] = None - uncompute: bool = False - - @cached_property - def control_registers(self) -> Tuple[Register, ...]: - return self.block_encoding.control_registers @cached_property def selection_registers(self) -> Tuple[Register, ...]: @@ -119,58 +116,30 @@ def junk_registers(self) -> Tuple[Register, ...]: @cached_property def signature(self) -> Signature: - return Signature( - [ - *self.control_registers, - *self.selection_registers, - *self.target_registers, - *self.junk_registers, - ] - ) + return Signature([*self.selection_registers, *self.target_registers, *self.junk_registers]) @cached_property def reflect(self) -> ReflectionUsingPrepare: - return ReflectionUsingPrepare( - self.block_encoding.signal_state, control_val=self.control_val, global_phase=-1 - ) + return ReflectionUsingPrepare(self.block_encoding.signal_state, global_phase=-1) @cached_property def sum_of_lcu_coefficients(self) -> SymbolicFloat: r"""value of $\lambda$, i.e. sum of absolute values of coefficients $w_l$.""" return self.block_encoding.alpha - def decompose_from_registers( - self, - context: cirq.DecompositionContext, - **quregs: NDArray[cirq.Qid], # type:ignore[type-var] - ) -> Iterator[cirq.OP_TREE]: - select_reg = {reg.name: quregs[reg.name] for reg in self.block_encoding.signature} - - reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.signature} - if self.uncompute: - yield self.reflect.adjoint().on_registers(**reflect_reg) - yield self.block_encoding.adjoint().on_registers(**select_reg) - else: - yield self.block_encoding.on_registers(**select_reg) - yield self.reflect.on_registers(**reflect_reg) - - def get_single_qubit_controlled_bloq(self, control_val: int) -> 'QubitizationWalkOperator': - assert self.control_val is None - c_block = self.block_encoding.controlled(ctrl_spec=CtrlSpec(cvs=control_val)) - if not isinstance(c_block, (SelectBlockEncoding, LCUBlockEncoding)): - raise TypeError( - f"controlled version of {self.block_encoding} = {c_block} must also be a SelectOracle" - ) - return attrs.evolve(self, block_encoding=c_block, control_val=control_val) + def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> dict[str, 'SoquetT']: + be_soqs = {reg.name: soqs.pop(reg.name) for reg in self.block_encoding.signature} + soqs |= bb.add_d(self.block_encoding, **be_soqs) + + reflect_soqs = {reg.name: soqs.pop(reg.name) for reg in self.reflect.signature} + soqs |= bb.add_d(self.reflect, **reflect_soqs) + + return soqs def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ['@' if self.control_val else '@(0)'] * total_bits(self.control_registers) - wire_symbols += ['W'] * (total_bits(self.signature) - total_bits(self.control_registers)) + wire_symbols = ['W'] * self.signature.n_qubits() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) - def adjoint(self) -> 'QubitizationWalkOperator': - return attrs.evolve(self, uncompute=not self.uncompute) - @cached_property def prepare(self) -> Union[PrepareOracle, BlackBoxPrepare]: """Get the Prepare bloq if appropriate from the block encoding.""" diff --git a/qualtran/bloqs/qubitization/qubitization_walk_operator_test.py b/qualtran/bloqs/qubitization/qubitization_walk_operator_test.py index 37019cf30..8e6d4f0cc 100644 --- a/qualtran/bloqs/qubitization/qubitization_walk_operator_test.py +++ b/qualtran/bloqs/qubitization/qubitization_walk_operator_test.py @@ -155,7 +155,7 @@ def decompose_twice(op): ''', ) # 3. Diagram for $Ctrl-W = Ctrl-B[H].Ctrl-R_{L}$ - controlled_walk_op = walk.controlled().on_registers(**g.quregs, control=cirq.q('control')) + controlled_walk_op = walk.controlled().on_registers(**g.quregs, ctrl=cirq.q('control')) circuit = cirq.Circuit(cirq.decompose_once(controlled_walk_op)) cirq.testing.assert_has_diagram( circuit,