diff --git a/qualtran/bloqs/factoring/ecc/ec_add.py b/qualtran/bloqs/factoring/ecc/ec_add.py index 5f8216777..572e47673 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add.py +++ b/qualtran/bloqs/factoring/ecc/ec_add.py @@ -41,12 +41,12 @@ CModNeg, CModSub, DirtyOutOfPlaceMontgomeryModMul, + KaliskiModInverse, ModAdd, ModDbl, ModNeg, ModSub, ) -from qualtran.bloqs.mod_arithmetic._shims import ModInv from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT from qualtran.symbolics.types import HasLength, is_symbolic @@ -285,7 +285,7 @@ def build_composite_bloq( ctrl, b, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=b, y=y) # Perform modular inversion s.t. x = (x - a)^-1 % p. - x, z1, z2 = bb.add(ModInv(n=self.n, mod=self.mod), x=x) + x, junk = bb.add(KaliskiModInverse(bitsize=self.n, mod=self.mod), x=x) # Perform modular multiplication z4 = (y / x) % p. x, y, z4, z3, reduced = bb.add( @@ -336,7 +336,7 @@ def build_composite_bloq( qrom_indices=z3, reduced=reduced, ) - x = bb.add(ModInv(n=self.n, mod=self.mod).adjoint(), x=x, garbage1=z1, garbage2=z2) + x = bb.add(KaliskiModInverse(bitsize=self.n, mod=self.mod).adjoint(), x=x, junk=junk) # Return the output registers. return {'f1': f1, 'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam, 'lam_r': lam_r} @@ -346,7 +346,7 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: Equals(QMontgomeryUInt(self.n)): 1, ModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, - ModInv(n=self.n, mod=self.mod): 1, + KaliskiModInverse(bitsize=self.n, mod=self.mod): 1, DirtyOutOfPlaceMontgomeryModMul( bitsize=self.n, window_size=self.window_size, mod=self.mod ): 1, @@ -355,7 +355,7 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: DirtyOutOfPlaceMontgomeryModMul( bitsize=self.n, window_size=self.window_size, mod=self.mod ).adjoint(): 1, - ModInv(n=self.n, mod=self.mod).adjoint(): 1, + KaliskiModInverse(bitsize=self.n, mod=self.mod).adjoint(): 1, } @@ -706,7 +706,7 @@ def build_composite_bloq( raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") # x = x ^ -1 % p. - x, z1, z2 = bb.add(ModInv(n=self.n, mod=self.mod), x=x) + x, junk = bb.add(KaliskiModInverse(bitsize=self.n, mod=self.mod), x=x) # z4 = x * y % p. x, y, z4, z3, reduced = bb.add( @@ -744,7 +744,7 @@ def build_composite_bloq( qrom_indices=z3, reduced=reduced, ) - x = bb.add(ModInv(n=self.n, mod=self.mod).adjoint(), x=x, garbage1=z1, garbage2=z2) + x = bb.add(KaliskiModInverse(bitsize=self.n, mod=self.mod).adjoint(), x=x, junk=junk) # If ctrl: x = x_r - a % p. ctrl, x = bb.add(CModNeg(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=x) @@ -761,14 +761,14 @@ def build_composite_bloq( def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: return { CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, - ModInv(n=self.n, mod=self.mod): 1, + KaliskiModInverse(bitsize=self.n, mod=self.mod): 1, DirtyOutOfPlaceMontgomeryModMul( bitsize=self.n, window_size=self.window_size, mod=self.mod ): 1, DirtyOutOfPlaceMontgomeryModMul( bitsize=self.n, window_size=self.window_size, mod=self.mod ).adjoint(): 1, - ModInv(n=self.n, mod=self.mod).adjoint(): 1, + KaliskiModInverse(bitsize=self.n, mod=self.mod).adjoint(): 1, ModAdd(self.n, mod=self.mod): 1, MultiControlX(cvs=[1, 1]): self.n, CModNeg(QMontgomeryUInt(self.n), mod=self.mod): 1, diff --git a/qualtran/bloqs/factoring/ecc/ec_add_test.py b/qualtran/bloqs/factoring/ecc/ec_add_test.py index 5295316f1..37c397707 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add_test.py +++ b/qualtran/bloqs/factoring/ecc/ec_add_test.py @@ -407,19 +407,18 @@ def test_ec_add_symbolic_cost(): # This is why instead of using bitsize=n directly, we use bitsize=4*m=n. b = ECAdd(n=4 * m, window_size=4, mod=p) cost = get_cost_value(b, QECGatesCost()).total_t_and_ccz_count() - assert cost['n_t'] == 0 + # We have some T gates since we use CSwapApprox instead of n CSWAPs in KaliskiModInverse. + total_toff = (cost['n_t'] / 4 + cost['n_ccz']) * sympy.Integer(1) + total_toff = total_toff.subs(m, n / 4).expand() # Litinski 2023 https://arxiv.org/abs/2306.08585 # Based on the counts from Figures 3, 5, and 8 the toffoli count for ECAdd is 126.5n^2 + 189n. - # The following formula is 126.5n^2 + 175.5n - 35. We account for the discrepancy in the - # coefficient of n by a reduction in the toffoli cost of Montgomery ModMult, n extra toffolis - # in ModNeg, and 2n extra toffolis to do n 3-controlled toffolis in step 2. The expression is - # written with rationals because sympy comparison fails with floats. - assert isinstance(cost['n_ccz'], sympy.Expr) - assert ( - cost['n_ccz'].subs(m, n / 4).expand() - == sympy.Rational(253, 2) * n**2 + sympy.Rational(351, 2) * n - 35 - ) + # The following formula is 126.5n^2 + 195.5n - 31. We account for the discrepancy in the + # coefficient of n by a reduction in the toffoli cost of Montgomery ModMult, an increase in the + # toffoli cost for Kaliski Mod Inverse, n extra toffolis in ModNeg, 2n extra toffolis to do n + # 3-controlled toffolis in step 2. The expression is written with rationals because sympy + # comparison fails with floats. + assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(391, 2) * n - 31 def test_ec_add(bloq_autotester): diff --git a/qualtran/bloqs/mod_arithmetic/__init__.py b/qualtran/bloqs/mod_arithmetic/__init__.py index 0ddb3fc2b..774aa44d6 100644 --- a/qualtran/bloqs/mod_arithmetic/__init__.py +++ b/qualtran/bloqs/mod_arithmetic/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._shims import ModInv from .mod_addition import CModAdd, CModAddK, CtrlScaleModAdd, ModAdd, ModAddK from .mod_division import KaliskiModInverse from .mod_multiplication import CModMulK, DirtyOutOfPlaceMontgomeryModMul, ModDbl diff --git a/qualtran/bloqs/mod_arithmetic/_shims.py b/qualtran/bloqs/mod_arithmetic/_shims.py deleted file mode 100644 index 7242f6afa..000000000 --- a/qualtran/bloqs/mod_arithmetic/_shims.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""This module has a selection of minimally-implemented modular arithmetic primitives. - -These bloqs serve as the callees in the call graphs of the algorithms found -in `qualtran.bloq.factoring`. They are place-holders, so we don't have undefined symbols -and can still merge the high-level algorithms. These shims will be fleshed out -and moved to their final organizational location soon (written: 2024-05-06). -""" - - -from collections import defaultdict -from functools import cached_property -from typing import Dict, Optional, Tuple, TYPE_CHECKING - -import attrs -from attrs import frozen - -from qualtran import Bloq, QUInt, Register, Side, Signature -from qualtran.bloqs.arithmetic import AddK, Negate -from qualtran.bloqs.arithmetic._shims import CHalf, CSub, Lt, MultiCToffoli -from qualtran.bloqs.arithmetic.controlled_addition import CAdd -from qualtran.bloqs.basic_gates import CNOT, CSwap, Swap, Toffoli -from qualtran.bloqs.mod_arithmetic.mod_multiplication import ModDbl -from qualtran.drawing import Text, TextBox, WireSymbol -from qualtran.simulation.classical_sim import ClassicalValT - -if TYPE_CHECKING: - from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator - - -@frozen -class _ModInvInner(Bloq): - n: int - mod: int - - @cached_property - def signature(self) -> 'Signature': - return Signature([Register('x', QUInt(self.n)), Register('out', QUInt(self.n))]) - - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': - # This listing is based off of Haner 2023, fig 15. The order of operations - # matches the order in the figure - listing = [ - (MultiCToffoli(self.n + 1), 1), - (CNOT(), 1), - (Toffoli(), 1), - (MultiCToffoli(n=3), 1), - (CNOT(), 2), - (Lt(self.n), 1), - (CSwap(self.n), 2), - (CSub(self.n), 1), - (CAdd(QUInt(self.n)), 1), - (CNOT(), 2), - (ModDbl(QUInt(self.n), self.mod), 1), - (CHalf(self.n), 1), - (CSwap(self.n), 2), - (CNOT(), 1), - ] - # Since the listing is time-ordered and the call graph protocol expects - # unique bloq keys, we group counts by bloqs. - summer: Dict[Bloq, int] = defaultdict(lambda: 0) - for bloq, n in listing: - summer[bloq] += n - return summer - - def wire_symbol( - self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple() - ) -> 'WireSymbol': - if reg is None: - return Text("") - if reg.name == 'x': - return TextBox('x') - elif reg.name == 'out': - return TextBox('$x^{-1}$') - raise ValueError(f'Unrecognized register name {reg.name}') - - -@frozen -class ModInv(Bloq): - n: int - mod: int - uncompute: bool = False - - @cached_property - def signature(self) -> 'Signature': - side = Side.LEFT if self.uncompute else Side.RIGHT - return Signature( - [ - Register('x', QUInt(self.n)), - Register('garbage1', QUInt(self.n), side=side), - Register('garbage2', QUInt(self.n), side=side), - ] - ) - - def adjoint(self) -> 'ModInv': - return attrs.evolve(self, uncompute=self.uncompute ^ True) - - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': - # Roetteler - # return {(Toffoli(), 32 * self.n**2 * log2(self.n))} - return { - _ModInvInner(n=self.n, mod=self.mod): 2 * self.n, - Negate(QUInt(self.n)): 1, - AddK(self.n, k=self.mod): 1, - Swap(self.n): 1, - } - - def on_classical_vals( - self, - x: 'ClassicalValT', - garbage1: Optional['ClassicalValT'] = None, - garbage2: Optional['ClassicalValT'] = None, - ) -> Dict[str, ClassicalValT]: - # TODO(https://github.com/quantumlib/Qualtran/issues/1443): Hacky classical simulation just - # to confirm correctness of ECAdd circuit. - if self.uncompute: - assert garbage1 is not None - assert garbage2 is not None - return {'x': garbage1} - assert garbage1 is None - assert garbage2 is None - - # Store the original x in the garbage registers for the uncompute simulation. - garbage1 = x - garbage2 = x - - x = pow(int(x), self.mod - 2, mod=self.mod) * pow(2, 2 * self.n, self.mod) % self.mod - - return {'x': x, 'garbage1': garbage1, 'garbage2': garbage2} - - def wire_symbol( - self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple() - ) -> 'WireSymbol': - if reg is None: - return Text("") - if reg.name == 'x': - return TextBox('x') - elif reg.name == 'out': - return TextBox('$x^{-1}$') - elif reg.name == 'garbage1': - return TextBox('garbage1') - elif reg.name == 'garbage2': - return TextBox('garbage2') - raise ValueError(f'Unrecognized register name {reg.name}')