From e3aeee0d5e8f72304b2c8a1124196954e764ce0a Mon Sep 17 00:00:00 2001 From: Frankie Papa Date: Wed, 23 Oct 2024 14:27:30 -0700 Subject: [PATCH] Add ECAdd() Bloq (#1425) * Initial commit of ec add waiting on equals to be merged. * Working on tests for ECAdd * ECAdd implementation and tests * remove modmul typo * Fix mypy errors * Better bugfix for ModAdd * Change mod inv classical impl to use montgomery inv * Fix pytest error * Fix some comments * ECAdd lots of testing * Add comments about bugs to be fixed * Reduce complexity by keeping intermediate values mod p * Stash qmontgomery tests * Address comments * Fix montgomery prod/inv calculations + pylint/mypy --------- Co-authored-by: Noureldin Co-authored-by: Matthew Harrigan --- qualtran/_infra/data_types.py | 44 +- qualtran/_infra/data_types_test.py | 26 + qualtran/bloqs/arithmetic/_shims.py | 22 +- qualtran/bloqs/factoring/ecc/ec_add.ipynb | 36 +- qualtran/bloqs/factoring/ecc/ec_add.py | 1075 ++++++++++++++++- qualtran/bloqs/factoring/ecc/ec_add_test.py | 412 ++++++- qualtran/bloqs/factoring/ecc/ec_point.py | 3 +- qualtran/bloqs/factoring/ecc/ec_point_test.py | 1 + qualtran/bloqs/mod_arithmetic/_shims.py | 55 +- .../mod_arithmetic/mod_multiplication.py | 3 +- qualtran/serialization/resolver_dict.py | 8 + 11 files changed, 1639 insertions(+), 46 deletions(-) diff --git a/qualtran/_infra/data_types.py b/qualtran/_infra/data_types.py index e21eb209f..ee918d506 100644 --- a/qualtran/_infra/data_types.py +++ b/qualtran/_infra/data_types.py @@ -772,9 +772,14 @@ class QMontgomeryUInt(QDType): bitsize: The number of qubits used to represent the integer. References: - [Montgomery modular multiplication](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication) + [Montgomery modular multiplication](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication).
+ + [Performance Analysis of a Repetition Cat Code Architecture: Computing 256-bit Elliptic Curve Logarithm in 9 Hours with 126133 Cat Qubits](https://arxiv.org/abs/2302.06639). + Gouzien et al. 2023. + We follow Montgomery form as described in the above paper; namely, r = 2^bitsize. """ + # TODO(https://github.com/quantumlib/Qualtran/issues/1471): Add modulus p as a class member. bitsize: SymbolicInt @property @@ -810,6 +815,43 @@ def assert_valid_classical_val_array( if np.any(val_array >= 2**self.bitsize): raise ValueError(f"Too-large classical values encountered in {debug_str}") + def montgomery_inverse(self, xm: int, p: int) -> int: + """Returns the modular inverse of an integer in montgomery form. + + Args: + xm: An integer in montgomery form. + p: The modulus of the finite field. + """ + return ((pow(xm, -1, p)) * pow(2, 2 * self.bitsize, p)) % p + + def montgomery_product(self, xm: int, ym: int, p: int) -> int: + """Returns the modular product of two integers in montgomery form. + + Args: + xm: The first montgomery form integer for the product. + ym: The second montgomery form integer for the product. + p: The modulus of the finite field. + """ + return (xm * ym * pow(2, -self.bitsize, p)) % p + + def montgomery_to_uint(self, xm: int, p: int) -> int: + """Converts an integer in montgomery form to a normal form integer. + + Args: + xm: An integer in montgomery form. + p: The modulus of the finite field. + """ + return (xm * pow(2, -self.bitsize, p)) % p + + def uint_to_montgomery(self, x: int, p: int) -> int: + """Converts an integer into montgomery form. + + Args: + x: An integer. + p: The modulus of the finite field. 
+ """ + return (x * pow(2, int(self.bitsize), p)) % p + @attrs.frozen class QGF(QDType): diff --git a/qualtran/_infra/data_types_test.py b/qualtran/_infra/data_types_test.py index 10347b702..65252c5cb 100644 --- a/qualtran/_infra/data_types_test.py +++ b/qualtran/_infra/data_types_test.py @@ -135,6 +135,32 @@ def test_qmontgomeryuint(): assert is_symbolic(QMontgomeryUInt(sympy.Symbol('x'))) +@pytest.mark.parametrize('p', [13, 17, 29]) +@pytest.mark.parametrize('val', [1, 5, 7, 9]) +def test_qmontgomeryuint_operations(val, p): + qmontgomeryuint_8 = QMontgomeryUInt(8) + # Convert value to montgomery form and get the modular inverse. + val_m = qmontgomeryuint_8.uint_to_montgomery(val, p) + mod_inv = qmontgomeryuint_8.montgomery_inverse(val_m, p) + + # Calculate the product in montgomery form and convert back to normal form for assertion. + assert ( + qmontgomeryuint_8.montgomery_to_uint( + qmontgomeryuint_8.montgomery_product(val_m, mod_inv, p), p + ) + == 1 + ) + + +@pytest.mark.parametrize('p', [13, 17, 29]) +@pytest.mark.parametrize('val', [1, 5, 7, 9]) +def test_qmontgomeryuint_conversions(val, p): + qmontgomeryuint_8 = QMontgomeryUInt(8) + assert val == qmontgomeryuint_8.montgomery_to_uint( + qmontgomeryuint_8.uint_to_montgomery(val, p), p + ) + + def test_qgf(): qgf_256 = QGF(characteristic=2, degree=8) assert str(qgf_256) == 'QGF(2**8)' diff --git a/qualtran/bloqs/arithmetic/_shims.py b/qualtran/bloqs/arithmetic/_shims.py index 0d40daba7..d75e3d72f 100644 --- a/qualtran/bloqs/arithmetic/_shims.py +++ b/qualtran/bloqs/arithmetic/_shims.py @@ -22,8 +22,11 @@ from attrs import frozen -from qualtran import Bloq, QBit, QUInt, Register, Signature +from qualtran import Bloq, QBit, QMontgomeryUInt, QUInt, Register, Signature +from qualtran.bloqs.arithmetic.bitwise import BitwiseNot +from qualtran.bloqs.arithmetic.controlled_addition import CAdd from qualtran.bloqs.basic_gates import Toffoli +from qualtran.bloqs.basic_gates.swap import TwoBitCSwap from 
qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator @@ -39,6 +42,20 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: return {Toffoli(): self.n - 2} +@frozen +class CSub(Bloq): + n: int + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [Register('ctrl', QBit()), Register('x', QUInt(self.n)), Register('y', QUInt(self.n))] + ) + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return {CAdd(QMontgomeryUInt(self.n)): 1, BitwiseNot(QMontgomeryUInt(self.n)): 3} + + @frozen class Lt(Bloq): n: int @@ -62,3 +79,6 @@ class CHalf(Bloq): @cached_property def signature(self) -> 'Signature': return Signature([Register('ctrl', QBit()), Register('x', QUInt(self.n))]) + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return {TwoBitCSwap(): self.n} diff --git a/qualtran/bloqs/factoring/ecc/ec_add.ipynb b/qualtran/bloqs/factoring/ecc/ec_add.ipynb index 543458c8c..cbc279bbe 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add.ipynb +++ b/qualtran/bloqs/factoring/ecc/ec_add.ipynb @@ -41,16 +41,22 @@ "This takes elliptic curve points given by (a, b) and (x, y)\n", "and outputs the sum (x_r, y_r) in the second pair of registers.\n", "\n", + "Because the decomposition of this Bloq is complex, we split it into six separate parts\n", + "corresponding to the parts described in figure 10 of the Litinski paper cited below. We follow\n", + "the signature from figure 5 and break down the further decompositions based on the steps in\n", + "figure 10.\n", + "\n", "#### Parameters\n", " - `n`: The bitsize of the two registers storing the elliptic curve point\n", - " - `mod`: The modulus of the field in which we do the addition. \n", + " - `mod`: The modulus of the field in which we do the addition.\n", + " - `window_size`: The number of bits in the ModMult window. 
\n", "\n", "#### Registers\n", - " - `a`: The x component of the first input elliptic curve point of bitsize `n`.\n", - " - `b`: The y component of the first input elliptic curve point of bitsize `n`.\n", - " - `x`: The x component of the second input elliptic curve point of bitsize `n`, which will contain the x component of the resultant curve point.\n", - " - `y`: The y component of the second input elliptic curve point of bitsize `n`, which will contain the y component of the resultant curve point.\n", - " - `lam`: The precomputed lambda slope used in the addition operation. \n", + " - `a`: The x component of the first input elliptic curve point of bitsize `n` in montgomery form.\n", + " - `b`: The y component of the first input elliptic curve point of bitsize `n` in montgomery form.\n", + " - `x`: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the x component of the resultant curve point.\n", + " - `y`: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the y component of the resultant curve point.\n", + " - `lam_r`: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form. \n", "\n", "#### References\n", " - [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). Litinski. 2023. 
Fig 5.\n" @@ -91,6 +97,18 @@ "ec_add = ECAdd(n, mod=p)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "170da165", + "metadata": { + "cq.autogen": "ECAdd.ec_add_small" + }, + "outputs": [], + "source": [ + "ec_add_small = ECAdd(5, mod=7)" + ] + }, { "cell_type": "markdown", "id": "39210af4", @@ -111,8 +129,8 @@ "outputs": [], "source": [ "from qualtran.drawing import show_bloqs\n", - "show_bloqs([ec_add],\n", - " ['`ec_add`'])" + "show_bloqs([ec_add, ec_add_small],\n", + " ['`ec_add`', '`ec_add_small`'])" ] }, { @@ -157,7 +175,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/qualtran/bloqs/factoring/ecc/ec_add.py b/qualtran/bloqs/factoring/ecc/ec_add.py index 74a57706a..5f8216777 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add.py +++ b/qualtran/bloqs/factoring/ecc/ec_add.py @@ -12,12 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import cached_property +from typing import Dict, Union +import numpy as np import sympy from attrs import frozen -from qualtran import Bloq, bloq_example, BloqDocSpec, QUInt, Register, Signature -from qualtran.bloqs.arithmetic._shims import MultiCToffoli +from qualtran import ( + Bloq, + bloq_example, + BloqBuilder, + BloqDocSpec, + DecomposeTypeError, + QBit, + QMontgomeryUInt, + Register, + Side, + Signature, + Soquet, + SoquetT, +) +from qualtran.bloqs.arithmetic.comparison import Equals +from qualtran.bloqs.basic_gates import CNOT, IntState, Toffoli, ZeroState +from qualtran.bloqs.bookkeeping import Free +from qualtran.bloqs.mcmt import MultiAnd, MultiControlX, MultiTargetCNOT from qualtran.bloqs.mod_arithmetic import ( CModAdd, CModNeg, @@ -30,6 +48,925 @@ ) from qualtran.bloqs.mod_arithmetic._shims import ModInv from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator +from qualtran.simulation.classical_sim import ClassicalValT +from qualtran.symbolics.types import HasLength, is_symbolic + +from .ec_point import ECPoint + + +@frozen +class _ECAddStepOne(Bloq): + r"""Performs step one of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + + Registers: + f1: Flag to set if a = x. + f2: Flag to set if b = -y. + f3: Flag to set if (a, b) = (0, 0). + f4: Flag to set if (x, y) = (0, 0). + ctrl: Flag to set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. 
+ y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('f1', QBit(), side=Side.RIGHT), + Register('f2', QBit(), side=Side.RIGHT), + Register('f3', QBit(), side=Side.RIGHT), + Register('f4', QBit(), side=Side.RIGHT), + Register('ctrl', QBit(), side=Side.RIGHT), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, a: 'ClassicalValT', b: 'ClassicalValT', x: 'ClassicalValT', y: 'ClassicalValT' + ) -> Dict[str, 'ClassicalValT']: + f1 = int(a == x) + f2 = int(b == (-y % self.mod)) + f3 = int(a == b == 0) + f4 = int(x == y == 0) + ctrl = int(f2 == f3 == f4 == 0) + return { + 'f1': f1, + 'f2': f2, + 'f3': f3, + 'f4': f4, + 'ctrl': ctrl, + 'a': a, + 'b': b, + 'x': x, + 'y': y, + } + + def build_composite_bloq( + self, bb: 'BloqBuilder', a: Soquet, b: Soquet, x: Soquet, y: Soquet + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Initialize control flags to 0. + f1 = bb.add(ZeroState()) + f2 = bb.add(ZeroState()) + f3 = bb.add(ZeroState()) + f4 = bb.add(ZeroState()) + ctrl = bb.add(ZeroState()) + + # Set flag 1 if a = x. + a, x, f1 = bb.add(Equals(QMontgomeryUInt(self.n)), x=a, y=x, target=f1) + + # Set flag 2 if b = -y. + y = bb.add(ModNeg(QMontgomeryUInt(self.n), mod=self.mod), x=y) + b, y, f2 = bb.add(Equals(QMontgomeryUInt(self.n)), x=b, y=y, target=f2) + y = bb.add(ModNeg(QMontgomeryUInt(self.n), mod=self.mod), x=y) + + # Set flag 3 if (a, b) == (0, 0). 
+ ab_arr = np.concatenate([bb.split(a), bb.split(b)]) + ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3) + ab_arr = np.split(ab_arr, 2) + a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n)) + + # Set flag 4 if (x, y) == (0, 0). + xy_arr = np.concatenate([bb.split(x), bb.split(y)]) + xy_arr, f4 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=xy_arr, target=f4) + xy_arr = np.split(xy_arr, 2) + x = bb.join(xy_arr[0], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xy_arr[1], dtype=QMontgomeryUInt(self.n)) + + # Set ctrl flag if none of f2, f3, f4 are set. + f_ctrls = [f2, f3, f4] + f_ctrls, ctrl = bb.add(MultiControlX(cvs=[0] * 3), controls=f_ctrls, target=ctrl) + f2 = f_ctrls[0] + f3 = f_ctrls[1] + f4 = f_ctrls[2] + + # Return the output registers. + return { + 'f1': f1, + 'f2': f2, + 'f3': f3, + 'f4': f4, + 'ctrl': ctrl, + 'a': a, + 'b': b, + 'x': x, + 'y': y, + } + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + cvs: Union[list[int], HasLength] + if isinstance(self.n, int): + cvs = [0] * 2 * self.n + else: + cvs = HasLength(2 * self.n) + return { + Equals(QMontgomeryUInt(self.n)): 2, + ModNeg(QMontgomeryUInt(self.n), mod=self.mod): 2, + MultiControlX(cvs=cvs): 2, + MultiControlX(cvs=[0] * 3): 1, + } + + +@frozen +class _ECAddStepTwo(Bloq): + r"""Performs step two of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + f1: Flag set if a = x. + ctrl: Flag set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form.
+ x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + lam_r: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('f1', QBit()), + Register('ctrl', QBit()), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n), side=Side.RIGHT), + Register('lam_r', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, + f1: 'ClassicalValT', + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + lam_r: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + x = (x - a) % self.mod + if ctrl == 1: + y = (y - b) % self.mod + if f1 == 1: + lam = lam_r + f1 = 0 + else: + lam = QMontgomeryUInt(self.n).montgomery_product( + int(y), QMontgomeryUInt(self.n).montgomery_inverse(int(x), self.mod), self.mod + ) + # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit + # which flips f1 when lam and lam_r are equal. 
+ if lam == lam_r: + f1 = (f1 + 1) % 2 + else: + lam = 0 + return {'f1': f1, 'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam, 'lam_r': lam_r} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + f1: Soquet, + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + lam_r: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Initialize lambda to 0. + lam = bb.add(IntState(bitsize=self.n, val=0)) + + # Perform modular subtraction so that x = (x - a) % p. + a, x = bb.add(ModSub(QMontgomeryUInt(self.n), mod=self.mod), x=a, y=x) + + # Perform controlled modular subtraction so that y = (y - b) % p iff ctrl = 1. + ctrl, b, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=b, y=y) + + # Perform modular inversion s.t. x = (x - a)^-1 % p. + x, z1, z2 = bb.add(ModInv(n=self.n, mod=self.mod), x=x) + + # Perform modular multiplication z4 = (y / x) % p. + x, y, z4, z3, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=y, + ) + + # If ctrl = 1 and x != a: lam = (y - b) / (x - a) % p. + z4_split = bb.split(z4) + lam_split = bb.split(lam) + for i in range(self.n): + ctrls = [f1, ctrl, z4_split[i]] + ctrls, lam_split[i] = bb.add( + MultiControlX(cvs=[0, 1, 1]), controls=ctrls, target=lam_split[i] + ) + f1 = ctrls[0] + ctrl = ctrls[1] + z4_split[i] = ctrls[2] + z4 = bb.join(z4_split, dtype=QMontgomeryUInt(self.n)) + + # If ctrl = 1 and x = a: lam = lam_r. + lam_r_split = bb.split(lam_r) + for i in range(self.n): + ctrls = [f1, ctrl, lam_r_split[i]] + ctrls, lam_split[i] = bb.add( + MultiControlX(cvs=[1, 1, 1]), controls=ctrls, target=lam_split[i] + ) + f1 = ctrls[0] + ctrl = ctrls[1] + lam_r_split[i] = ctrls[2] + lam_r = bb.join(lam_r_split, dtype=QMontgomeryUInt(self.n)) + lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) + + # If lam = lam_r: return f1 = 0.
(If not we will flip f1 to 0 at the end iff x_r = y_r = 0). + lam, lam_r, f1 = bb.add(Equals(QMontgomeryUInt(self.n)), x=lam, y=lam_r, target=f1) + + # Uncompute the modular multiplication then the modular inversion. + x, y = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=y, + target=z4, + qrom_indices=z3, + reduced=reduced, + ) + x = bb.add(ModInv(n=self.n, mod=self.mod).adjoint(), x=x, garbage1=z1, garbage2=z2) + + # Return the output registers. + return {'f1': f1, 'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam, 'lam_r': lam_r} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + Equals(QMontgomeryUInt(self.n)): 1, + ModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModInv(n=self.n, mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 1, + MultiControlX(cvs=[0, 1, 1]): self.n, + MultiControlX(cvs=[1, 1, 1]): self.n, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 1, + ModInv(n=self.n, mod=self.mod).adjoint(): 1, + } + + +@frozen +class _ECAddStepThree(Bloq): + r"""Performs step three of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + ctrl: Flag set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. 
+ y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('ctrl', QBit()), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + lam: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + if ctrl == 1: + x = (x + 3 * a) % self.mod + y = 0 + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + lam: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Store (x - a) * lam % p in z1 (= (y - b) % p). + x, lam, z1, z2, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=lam, + ) + + # If ctrl: subtract z1 from y (= 0). + ctrl, z1, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=z1, y=y) + + # Uncompute original multiplication. + x, lam = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=lam, + target=z1, + qrom_indices=z2, + reduced=reduced, + ) + + # z1 = a. 
+ z1 = bb.add(IntState(bitsize=self.n, val=0)) + a_split = bb.split(a) + z1_split = bb.split(z1) + for i in range(self.n): + a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i]) + a = bb.join(a_split, QMontgomeryUInt(self.n)) + z1 = bb.join(z1_split, QMontgomeryUInt(self.n)) + + # z1 = (3 * a) % p. + z1 = bb.add(ModDbl(QMontgomeryUInt(self.n), mod=self.mod), x=z1) + a, z1 = bb.add(ModAdd(self.n, mod=self.mod), x=a, y=z1) + + # If ctrl: x = (x + 2 * a) % p. + ctrl, z1, x = bb.add(CModAdd(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=z1, y=x) + + # Uncompute z1. + a, z1 = bb.add(ModAdd(self.n, mod=self.mod).adjoint(), x=a, y=z1) + z1 = bb.add(ModDbl(QMontgomeryUInt(self.n), mod=self.mod).adjoint(), x=z1) + a_split = bb.split(a) + z1_split = bb.split(z1) + for i in range(self.n): + a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i]) + a = bb.join(a_split, QMontgomeryUInt(self.n)) + z1 = bb.join(z1_split, QMontgomeryUInt(self.n)) + bb.add(Free(QMontgomeryUInt(self.n)), reg=z1) + + # Return the output registers. + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 1, + CNOT(): 2 * self.n, + ModDbl(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModAdd(self.n, mod=self.mod): 1, + CModAdd(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModAdd(self.n, mod=self.mod).adjoint(): 1, + ModDbl(QMontgomeryUInt(self.n), mod=self.mod).adjoint(): 1, + } + + +@frozen +class _ECAddStepFour(Bloq): + r"""Performs step four of the ECAdd bloq. 
+ + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, x: 'ClassicalValT', y: 'ClassicalValT', lam: 'ClassicalValT' + ) -> Dict[str, 'ClassicalValT']: + x = ( + x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), self.mod) + ) % self.mod + if lam > 0: + y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), self.mod) + return {'x': x, 'y': y, 'lam': lam} + + def build_composite_bloq( + self, bb: 'BloqBuilder', x: Soquet, y: Soquet, lam: Soquet + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Initialize z4 = lam. + z4 = bb.add(IntState(bitsize=self.n, val=0)) + lam_split = bb.split(lam) + z4_split = bb.split(z4) + for i in range(self.n): + lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i]) + lam = bb.join(lam_split, QMontgomeryUInt(self.n)) + z4 = bb.join(z4_split, QMontgomeryUInt(self.n)) + + # z3 = lam * lam % p. 
+ z4, lam, z3, z2, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=z4, + y=lam, + ) + + # x = a - x_r % p. + z3, x = bb.add(ModSub(QMontgomeryUInt(self.n), mod=self.mod), x=z3, y=x) + + # Uncompute the multiplication and initialization of z4. + z4, lam = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=z4, + y=lam, + target=z3, + qrom_indices=z2, + reduced=reduced, + ) + lam_split = bb.split(lam) + z4_split = bb.split(z4) + for i in range(self.n): + lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i]) + lam = bb.join(lam_split, QMontgomeryUInt(self.n)) + z4 = bb.join(z4_split, QMontgomeryUInt(self.n)) + bb.add(Free(QMontgomeryUInt(self.n)), reg=z4) + + # z3 = lam * x % p. + x, lam, z3, z4, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=lam, + ) + + # y = y_r + b % p. + z3_split = bb.split(z3) + y_split = bb.split(y) + for i in range(self.n): + z3_split[i], y_split[i] = bb.add(CNOT(), ctrl=z3_split[i], target=y_split[i]) + z3 = bb.join(z3_split, QMontgomeryUInt(self.n)) + y = bb.join(y_split, QMontgomeryUInt(self.n)) + + # Uncompute multiplication. + x, lam = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=lam, + target=z3, + qrom_indices=z4, + reduced=reduced, + ) + + # Return the output registers. 
+ return {'x': x, 'y': y, 'lam': lam} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + ModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 2, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 2, + CNOT(): 3 * self.n, + } + + +@frozen +class _ECAddStepFive(Bloq): + r"""Performs step five of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + ctrl: Flag set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. 
+ """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('ctrl', QBit()), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n), side=Side.LEFT), + ] + ) + + def on_classical_vals( + self, + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + lam: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + if ctrl == 1: + x = (a - x) % self.mod + y = (y - b) % self.mod + else: + x = (x + a) % self.mod + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + lam: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # x = x ^ -1 % p. + x, z1, z2 = bb.add(ModInv(n=self.n, mod=self.mod), x=x) + + # z4 = x * y % p. + x, y, z4, z3, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=y, + ) + + # If ctrl: lam = 0. + z4_split = bb.split(z4) + lam_split = bb.split(lam) + for i in range(self.n): + ctrls = [ctrl, z4_split[i]] + ctrls, lam_split[i] = bb.add( + MultiControlX(cvs=[1, 1]), controls=ctrls, target=lam_split[i] + ) + ctrl = ctrls[0] + z4_split[i] = ctrls[1] + z4 = bb.join(z4_split, dtype=QMontgomeryUInt(self.n)) + lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) + # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit where lambda + # is not set to 0 before being freed. + bb.add(Free(QMontgomeryUInt(self.n), dirty=True), reg=lam) + + # Uncompute multiplication and inverse. 
+ x, y = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=y, + target=z4, + qrom_indices=z3, + reduced=reduced, + ) + x = bb.add(ModInv(n=self.n, mod=self.mod).adjoint(), x=x, garbage1=z1, garbage2=z2) + + # If ctrl: x = x_r - a % p. + ctrl, x = bb.add(CModNeg(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=x) + + # Add a to x (x = x_r). + a, x = bb.add(ModAdd(self.n, mod=self.mod), x=a, y=x) + + # If ctrl: subtract b from y (y = y_r). + ctrl, b, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=b, y=y) + + # Return the output registers. + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModInv(n=self.n, mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 1, + ModInv(n=self.n, mod=self.mod).adjoint(): 1, + ModAdd(self.n, mod=self.mod): 1, + MultiControlX(cvs=[1, 1]): self.n, + CModNeg(QMontgomeryUInt(self.n), mod=self.mod): 1, + } + + +@frozen +class _ECAddStepSix(Bloq): + r"""Performs step six of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + + Registers: + f1: Flag to set if a = x. + f2: Flag to set if b = -y. + f3: Flag to set if (a, b) = (0, 0). + f4: Flag to set if (x, y) = (0, 0). + ctrl: Flag to set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. 
+ x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('f1', QBit(), side=Side.LEFT), + Register('f2', QBit(), side=Side.LEFT), + Register('f3', QBit(), side=Side.LEFT), + Register('f4', QBit(), side=Side.LEFT), + Register('ctrl', QBit(), side=Side.LEFT), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, + f1: 'ClassicalValT', + f2: 'ClassicalValT', + f3: 'ClassicalValT', + f4: 'ClassicalValT', + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + if f4 == 1: + x = a + y = b + if f1 and f2: + x = 0 + y = 0 + return {'a': a, 'b': b, 'x': x, 'y': y} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + f1: Soquet, + f2: Soquet, + f3: Soquet, + f4: Soquet, + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Unset control if none of the f2, f3, and f4 flags are set. + f_ctrls = [f2, f3, f4] + f_ctrls, ctrl = bb.add(MultiControlX(cvs=[0] * 3), controls=f_ctrls, target=ctrl) + f2 = f_ctrls[0] + f3 = f_ctrls[1] + f4 = f_ctrls[2] + + # Set (x, y) to (a, b) if f4 is set.
+ a_split = bb.split(a) + x_split = bb.split(x) + for i in range(self.n): + toff_ctrl = [f4, a_split[i]] + toff_ctrl, x_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=x_split[i]) + f4 = toff_ctrl[0] + a_split[i] = toff_ctrl[1] + a = bb.join(a_split, QMontgomeryUInt(self.n)) + x = bb.join(x_split, QMontgomeryUInt(self.n)) + b_split = bb.split(b) + y_split = bb.split(y) + for i in range(self.n): + toff_ctrl = [f4, b_split[i]] + toff_ctrl, y_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=y_split[i]) + f4 = toff_ctrl[0] + b_split[i] = toff_ctrl[1] + b = bb.join(b_split, QMontgomeryUInt(self.n)) + y = bb.join(y_split, QMontgomeryUInt(self.n)) + + # Unset f4 if (x, y) = (a, b). + ab = bb.join(np.concatenate([bb.split(a), bb.split(b)]), dtype=QMontgomeryUInt(2 * self.n)) + xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n)) + ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4) + ab_split = bb.split(ab) + a = bb.join(ab_split[: self.n], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_split[self.n :], dtype=QMontgomeryUInt(self.n)) + xy_split = bb.split(xy) + x = bb.join(xy_split[: self.n], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xy_split[self.n :], dtype=QMontgomeryUInt(self.n)) + + # Unset f3 if (a, b) = (0, 0). + ab_arr = np.concatenate([bb.split(a), bb.split(b)]) + ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3) + ab_arr = np.split(ab_arr, 2) + a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n)) + + # If f1 and f2 are set, subtract a from x and add b to y. 
+ ancilla = bb.add(ZeroState()) + toff_ctrl = [f1, f2] + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + ancilla, a, x = bb.add( + CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ancilla, x=a, y=x + ) + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + f1 = toff_ctrl[0] + f2 = toff_ctrl[1] + bb.add(Free(QBit()), reg=ancilla) + ancilla = bb.add(ZeroState()) + toff_ctrl = [f1, f2] + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + ancilla, b, y = bb.add( + CModAdd(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ancilla, x=b, y=y + ) + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + f1 = toff_ctrl[0] + f2 = toff_ctrl[1] + bb.add(Free(QBit()), reg=ancilla) + + # Unset f1 and f2 if (x, y) = (0, 0). + xy_arr = np.concatenate([bb.split(x), bb.split(y)]) + xy_arr, junk, out = bb.add(MultiAnd(cvs=[0] * 2 * self.n), ctrl=xy_arr) + targets = bb.join(np.array([f1, f2])) + out, targets = bb.add(MultiTargetCNOT(2), control=out, targets=targets) + targets = bb.split(targets) + f1 = targets[0] + f2 = targets[1] + xy_arr = bb.add( + MultiAnd(cvs=[0] * 2 * self.n).adjoint(), ctrl=xy_arr, junk=junk, target=out + ) + xy_arr = np.split(xy_arr, 2) + x = bb.join(xy_arr[0], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xy_arr[1], dtype=QMontgomeryUInt(self.n)) + + # Free all ancilla qubits in the zero state. + # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bugs in circuit where f1, + # f2, and f4 are freed before being set to 0. + bb.add(Free(QBit(), dirty=True), reg=f1) + bb.add(Free(QBit(), dirty=True), reg=f2) + bb.add(Free(QBit()), reg=f3) + bb.add(Free(QBit(), dirty=True), reg=f4) + bb.add(Free(QBit()), reg=ctrl) + + # Return the output registers. 
+ return {'a': a, 'b': b, 'x': x, 'y': y} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + cvs: Union[list[int], HasLength] + if isinstance(self.n, int): + cvs = [0] * 2 * self.n + else: + cvs = HasLength(2 * self.n) + return { + MultiControlX(cvs=cvs): 1, + MultiControlX(cvs=[0] * 3): 1, + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + CModAdd(QMontgomeryUInt(self.n), mod=self.mod): 1, + Toffoli(): 2 * self.n + 4, + Equals(QMontgomeryUInt(2 * self.n)): 1, + MultiAnd(cvs=cvs): 1, + MultiTargetCNOT(2): 1, + MultiAnd(cvs=cvs).adjoint(): 1, + } @frozen @@ -39,18 +976,24 @@ class ECAdd(Bloq): This takes elliptic curve points given by (a, b) and (x, y) and outputs the sum (x_r, y_r) in the second pair of registers. + Because the decomposition of this Bloq is complex, we split it into six separate parts + corresponding to the parts described in figure 10 of the Litinski paper cited below. We follow + the signature from figure 5 and break down the further decompositions based on the steps in + figure 10. + Args: n: The bitsize of the two registers storing the elliptic curve point mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. Registers: - a: The x component of the first input elliptic curve point of bitsize `n`. - b: The y component of the first input elliptic curve point of bitsize `n`. - x: The x component of the second input elliptic curve point of bitsize `n`, which + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the x component of the resultant curve point. 
- y: The y component of the second input elliptic curve point of bitsize `n`, which + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the y component of the resultant curve point. - lam: The precomputed lambda slope used in the addition operation. + lam_r: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form. References: [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). @@ -59,32 +1002,108 @@ class ECAdd(Bloq): n: int mod: int + window_size: int = 1 @cached_property def signature(self) -> 'Signature': return Signature( [ - Register('a', QUInt(self.n)), - Register('b', QUInt(self.n)), - Register('x', QUInt(self.n)), - Register('y', QUInt(self.n)), - Register('lam', QUInt(self.n)), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam_r', QMontgomeryUInt(self.n)), ] ) - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> BloqCountDictT: - # litinksi + def build_composite_bloq( + self, bb: 'BloqBuilder', a: Soquet, b: Soquet, x: Soquet, y: Soquet, lam_r: Soquet + ) -> Dict[str, 'SoquetT']: + f1, f2, f3, f4, ctrl, a, b, x, y = bb.add( + _ECAddStepOne(n=self.n, mod=self.mod), a=a, b=b, x=x, y=y + ) + f1, ctrl, a, b, x, y, lam, lam_r = bb.add( + _ECAddStepTwo(n=self.n, mod=self.mod, window_size=self.window_size), + f1=f1, + ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + lam_r=lam_r, + ) + ctrl, a, b, x, y, lam = bb.add( + _ECAddStepThree(n=self.n, mod=self.mod, window_size=self.window_size), + ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + lam=lam, + ) + x, y, lam = bb.add( + _ECAddStepFour(n=self.n, mod=self.mod, window_size=self.window_size), x=x, y=y, lam=lam + ) + ctrl, a, b, x, y = bb.add( + _ECAddStepFive(n=self.n, mod=self.mod, window_size=self.window_size), + 
ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + lam=lam, + ) + a, b, x, y = bb.add( + _ECAddStepSix(n=self.n, mod=self.mod), + f1=f1, + f2=f2, + f3=f3, + f4=f4, + ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + ) + + return {'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r} + + def on_classical_vals(self, a, b, x, y, lam_r) -> Dict[str, Union['ClassicalValT', sympy.Expr]]: + curve_a = ( + QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, self.mod) + * 2 + * QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod) + - (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod) ** 2) + ) % self.mod + p1 = ECPoint( + QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod), + QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod), + mod=self.mod, + curve_a=curve_a, + ) + p2 = ECPoint( + QMontgomeryUInt(self.n).montgomery_to_uint(x, self.mod), + QMontgomeryUInt(self.n).montgomery_to_uint(y, self.mod), + mod=self.mod, + curve_a=curve_a, + ) + result = p1 + p2 + return { + 'a': a, + 'b': b, + 'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, self.mod), + 'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, self.mod), + 'lam_r': lam_r, + } + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': return { - MultiCToffoli(n=self.n): 18, - ModAdd(bitsize=self.n, mod=self.mod): 3, - CModAdd(QUInt(self.n), mod=self.mod): 2, - ModSub(QUInt(self.n), mod=self.mod): 2, - CModSub(QUInt(self.n), mod=self.mod): 4, - ModNeg(QUInt(self.n), mod=self.mod): 2, - CModNeg(QUInt(self.n), mod=self.mod): 1, - ModDbl(QUInt(self.n), mod=self.mod): 2, - DirtyOutOfPlaceMontgomeryModMul(bitsize=self.n, window_size=4, mod=self.mod): 10, - ModInv(n=self.n, mod=self.mod): 4, + _ECAddStepOne(n=self.n, mod=self.mod): 1, + _ECAddStepTwo(n=self.n, mod=self.mod, window_size=self.window_size): 1, + _ECAddStepThree(n=self.n, mod=self.mod, window_size=self.window_size): 1, + _ECAddStepFour(n=self.n, mod=self.mod, window_size=self.window_size): 1, + _ECAddStepFive(n=self.n, mod=self.mod, 
window_size=self.window_size): 1, + _ECAddStepSix(n=self.n, mod=self.mod): 1, } @@ -95,4 +1114,10 @@ def _ec_add() -> ECAdd: return ec_add -_EC_ADD_DOC = BloqDocSpec(bloq_cls=ECAdd, examples=[_ec_add]) +@bloq_example +def _ec_add_small() -> ECAdd: + ec_add_small = ECAdd(5, mod=7) + return ec_add_small + + +_EC_ADD_DOC = BloqDocSpec(bloq_cls=ECAdd, examples=[_ec_add, _ec_add_small]) diff --git a/qualtran/bloqs/factoring/ecc/ec_add_test.py b/qualtran/bloqs/factoring/ecc/ec_add_test.py index 44a8b77e1..5295316f1 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add_test.py +++ b/qualtran/bloqs/factoring/ecc/ec_add_test.py @@ -12,13 +12,423 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest +import sympy + import qualtran.testing as qlt_testing -from qualtran.bloqs.factoring.ecc.ec_add import _ec_add +from qualtran._infra.data_types import QMontgomeryUInt +from qualtran.bloqs.factoring.ecc.ec_add import ( + _ec_add, + _ec_add_small, + _ECAddStepFive, + _ECAddStepFour, + _ECAddStepOne, + _ECAddStepSix, + _ECAddStepThree, + _ECAddStepTwo, + ECAdd, +) +from qualtran.resource_counting._bloq_counts import QECGatesCost +from qualtran.resource_counting._costing import get_cost_value +from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join + + +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 8) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize('a,b', [(15, 13), (2, 10)]) +@pytest.mark.parametrize('x,y', [(15, 13), (0, 0)]) +def test_ec_add_steps_classical_fast(n, m, a, b, x, y): + p = 17 + lam_num = (3 * a**2) % p + lam_denom = (2 * b) % p + lam_r = 0 if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + + a = QMontgomeryUInt(n).uint_to_montgomery(a, p) + b = QMontgomeryUInt(n).uint_to_montgomery(b, p) + x = QMontgomeryUInt(n).uint_to_montgomery(x, p) + y = QMontgomeryUInt(n).uint_to_montgomery(y, p) + lam_r = 
QMontgomeryUInt(n).uint_to_montgomery(lam_r, p) if lam_r != 0 else p + + bloq = _ECAddStepOne(n=n, mod=p) + ret1 = bloq.call_classically(a=a, b=b, x=x, y=y) + ret2 = bloq.decompose_bloq().call_classically(a=a, b=b, x=x, y=y) + assert ret1 == ret2 + + step_1 = _ECAddStepOne(n=n, mod=p).on_classical_vals(a=a, b=b, x=x, y=y) + bloq = _ECAddStepTwo(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + assert ret1 == ret2 + + step_2 = _ECAddStepTwo(n=n, mod=p, window_size=m).on_classical_vals( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + bloq = _ECAddStepThree(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + assert ret1 == ret2 + + step_3 = _ECAddStepThree(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + bloq = _ECAddStepFour(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + ret2 = bloq.decompose_bloq().call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + assert ret1 == ret2 + + step_4 = _ECAddStepFour(n=n, mod=p, window_size=m).on_classical_vals( + x=step_3['x'], y=step_3['y'], lam=step_3['lam'] + ) + bloq = _ECAddStepFive(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + 
ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + assert ret1 == ret2 + + step_5 = _ECAddStepFive(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + bloq = _ECAddStepSix(n=n, mod=p) + ret1 = bloq.call_classically( + f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + assert ret1 == ret2 + + +@pytest.mark.slow +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 9) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize( + 'a,b', + [ + (15, 13), + (2, 10), + (8, 3), + (12, 1), + (6, 6), + (5, 8), + (10, 15), + (1, 12), + (3, 0), + (1, 5), + (10, 2), + (0, 0), + ], +) +@pytest.mark.parametrize('x,y', [(15, 13), (5, 8), (10, 15), (1, 12), (3, 0), (1, 5), (10, 2)]) +def test_ec_add_steps_classical(n, m, a, b, x, y): + p = 17 + lam_num = (3 * a**2) % p + lam_denom = (2 * b) % p + lam_r = 0 if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + + a = QMontgomeryUInt(n).uint_to_montgomery(a, p) + b = QMontgomeryUInt(n).uint_to_montgomery(b, p) + x = QMontgomeryUInt(n).uint_to_montgomery(x, p) + y = QMontgomeryUInt(n).uint_to_montgomery(y, p) + lam_r = QMontgomeryUInt(n).uint_to_montgomery(lam_r, p) if lam_r != 0 else p + + bloq = _ECAddStepOne(n=n, mod=p) + ret1 = bloq.call_classically(a=a, b=b, x=x, y=y) + ret2 = bloq.decompose_bloq().call_classically(a=a, b=b, x=x, y=y) + assert ret1 == ret2 + + step_1 = _ECAddStepOne(n=n, mod=p).on_classical_vals(a=a, b=b, x=x, y=y) + bloq = _ECAddStepTwo(n=n, mod=p, window_size=m) + ret1 = 
bloq.call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + assert ret1 == ret2 + + step_2 = _ECAddStepTwo(n=n, mod=p, window_size=m).on_classical_vals( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + bloq = _ECAddStepThree(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + assert ret1 == ret2 + + step_3 = _ECAddStepThree(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + bloq = _ECAddStepFour(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + ret2 = bloq.decompose_bloq().call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + assert ret1 == ret2 + + step_4 = _ECAddStepFour(n=n, mod=p, window_size=m).on_classical_vals( + x=step_3['x'], y=step_3['y'], lam=step_3['lam'] + ) + bloq = _ECAddStepFive(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + assert ret1 == ret2 + + step_5 = _ECAddStepFive(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + bloq = _ECAddStepSix(n=n, mod=p) + ret1 = bloq.call_classically( + 
f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + assert ret1 == ret2 + + +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 8) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize('a,b', [(15, 13), (2, 10)]) +@pytest.mark.parametrize('x,y', [(15, 13), (0, 0)]) +def test_ec_add_classical_fast(n, m, a, b, x, y): + p = 17 + bloq = ECAdd(n=n, mod=p, window_size=m) + lam_num = (3 * a**2) % p + lam_denom = (2 * b) % p + lam_r = p if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + ret1 = bloq.call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + ret2 = bloq.decompose_bloq().call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + assert ret1 == ret2 + + +@pytest.mark.slow +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 9) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize( + 'a,b', + [ + (15, 13), + (2, 10), + (8, 3), + (12, 1), + (6, 6), + (5, 8), + (10, 15), + (1, 12), + (3, 0), + (1, 5), + (10, 2), + (0, 0), + ], +) +@pytest.mark.parametrize('x,y', [(15, 13), (5, 8), (10, 15), (1, 12), (3, 0), (1, 5), (10, 2)]) +def test_ec_add_classical(n, m, a, b, x, y): + p = 17 + bloq = ECAdd(n=n, mod=p, window_size=m) + lam_num = (3 * 
a**2) % p + lam_denom = (2 * b) % p + lam_r = p if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + ret1 = bloq.call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + ret2 = bloq.decompose_bloq().call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + assert ret1 == ret2 + + +@pytest.mark.parametrize('p', (7, 9, 11)) +@pytest.mark.parametrize( + ['n', 'window_size'], + [ + (n, window_size) + for n in range(5, 8) + for window_size in range(1, n + 1) + if n % window_size == 0 + ], +) +def test_ec_add_decomposition(n, window_size, p): + b = ECAdd(n=n, window_size=window_size, mod=p) + qlt_testing.assert_valid_bloq_decomposition(b) + + +@pytest.mark.parametrize('p', (7, 9, 11)) +@pytest.mark.parametrize( + ['n', 'window_size'], + [ + (n, window_size) + for n in range(5, 8) + for window_size in range(1, n + 1) + if n % window_size == 0 + ], +) +def test_ec_add_bloq_counts(n, window_size, p): + b = ECAdd(n=n, window_size=window_size, mod=p) + qlt_testing.assert_equivalent_bloq_counts(b, [ignore_alloc_free, ignore_split_join]) + + +def test_ec_add_symbolic_cost(): + n, m, p = sympy.symbols('n m p', integer=True) + + # In Litinski 2023 https://arxiv.org/abs/2306.08585 a window size of 4 is used. + # The cost function generally has floor/ceil division that disappear for bitsize=0 mod 4. + # This is why instead of using bitsize=n directly, we use bitsize=4*m=n. 
+ b = ECAdd(n=4 * m, window_size=4, mod=p) + cost = get_cost_value(b, QECGatesCost()).total_t_and_ccz_count() + assert cost['n_t'] == 0 + + # Litinski 2023 https://arxiv.org/abs/2306.08585 + # Based on the counts from Figures 3, 5, and 8 the toffoli count for ECAdd is 126.5n^2 + 189n. + # The following formula is 126.5n^2 + 175.5n - 35. We account for the discrepancy in the + # coefficient of n by a reduction in the toffoli cost of Montgomery ModMult, n extra toffolis + # in ModNeg, and 2n extra toffolis to do n 3-controlled toffolis in step 2. The expression is + # written with rationals because sympy comparison fails with floats. + assert isinstance(cost['n_ccz'], sympy.Expr) + assert ( + cost['n_ccz'].subs(m, n / 4).expand() + == sympy.Rational(253, 2) * n**2 + sympy.Rational(351, 2) * n - 35 + ) def test_ec_add(bloq_autotester): bloq_autotester(_ec_add) +def test_ec_add_small(bloq_autotester): + bloq_autotester(_ec_add_small) + + def test_notebook(): qlt_testing.execute_notebook('ec_add') diff --git a/qualtran/bloqs/factoring/ecc/ec_point.py b/qualtran/bloqs/factoring/ecc/ec_point.py index 968ebe0b5..c17ea5957 100644 --- a/qualtran/bloqs/factoring/ecc/ec_point.py +++ b/qualtran/bloqs/factoring/ecc/ec_point.py @@ -69,7 +69,8 @@ def __add__(self, other): return ECPoint(xr, yr, mod=self.mod, curve_a=self.curve_a) def __mul__(self, other): - assert other > 0, other + if other == 0: + return ECPoint.inf(mod=self.mod, curve_a=self.curve_a) x = self for _ in range(other - 1): x = x + self diff --git a/qualtran/bloqs/factoring/ecc/ec_point_test.py b/qualtran/bloqs/factoring/ecc/ec_point_test.py index f65981c80..1ac1b59bd 100644 --- a/qualtran/bloqs/factoring/ecc/ec_point_test.py +++ b/qualtran/bloqs/factoring/ecc/ec_point_test.py @@ -21,6 +21,7 @@ def test_ec_point_overrides(): assert 1 * p == p assert 2 * p == (p + p) assert 3 * p == (p + p + p) + assert 0 * p == ECPoint.inf(mod=17, curve_a=0) def test_ec_point_addition(): diff --git 
a/qualtran/bloqs/mod_arithmetic/_shims.py b/qualtran/bloqs/mod_arithmetic/_shims.py index 65b61b4ca..7242f6afa 100644 --- a/qualtran/bloqs/mod_arithmetic/_shims.py +++ b/qualtran/bloqs/mod_arithmetic/_shims.py @@ -24,14 +24,17 @@ from functools import cached_property from typing import Dict, Optional, Tuple, TYPE_CHECKING +import attrs from attrs import frozen -from qualtran import Bloq, QUInt, Register, Signature -from qualtran.bloqs.arithmetic import Add, AddK, Negate, Subtract -from qualtran.bloqs.arithmetic._shims import CHalf, Lt, MultiCToffoli +from qualtran import Bloq, QUInt, Register, Side, Signature +from qualtran.bloqs.arithmetic import AddK, Negate +from qualtran.bloqs.arithmetic._shims import CHalf, CSub, Lt, MultiCToffoli +from qualtran.bloqs.arithmetic.controlled_addition import CAdd from qualtran.bloqs.basic_gates import CNOT, CSwap, Swap, Toffoli from qualtran.bloqs.mod_arithmetic.mod_multiplication import ModDbl from qualtran.drawing import Text, TextBox, WireSymbol +from qualtran.simulation.classical_sim import ClassicalValT if TYPE_CHECKING: from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator @@ -57,9 +60,9 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': (CNOT(), 2), (Lt(self.n), 1), (CSwap(self.n), 2), - (Subtract(QUInt(self.n)), 1), - (Add(QUInt(self.n)), 1), - (CNOT(), 1), + (CSub(self.n), 1), + (CAdd(QUInt(self.n)), 1), + (CNOT(), 2), (ModDbl(QUInt(self.n), self.mod), 1), (CHalf(self.n), 1), (CSwap(self.n), 2), @@ -88,10 +91,21 @@ def wire_symbol( class ModInv(Bloq): n: int mod: int + uncompute: bool = False @cached_property def signature(self) -> 'Signature': - return Signature([Register('x', QUInt(self.n)), Register('out', QUInt(self.n))]) + side = Side.LEFT if self.uncompute else Side.RIGHT + return Signature( + [ + Register('x', QUInt(self.n)), + Register('garbage1', QUInt(self.n), side=side), + Register('garbage2', QUInt(self.n), side=side), + ] + ) + + def adjoint(self) -> 
'ModInv': + return attrs.evolve(self, uncompute=self.uncompute ^ True) def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': # Roetteler @@ -103,6 +117,29 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': Swap(self.n): 1, } + def on_classical_vals( + self, + x: 'ClassicalValT', + garbage1: Optional['ClassicalValT'] = None, + garbage2: Optional['ClassicalValT'] = None, + ) -> Dict[str, ClassicalValT]: + # TODO(https://github.com/quantumlib/Qualtran/issues/1443): Hacky classical simulation just + # to confirm correctness of ECAdd circuit. + if self.uncompute: + assert garbage1 is not None + assert garbage2 is not None + return {'x': garbage1} + assert garbage1 is None + assert garbage2 is None + + # Store the original x in the garbage registers for the uncompute simulation. + garbage1 = x + garbage2 = x + + x = pow(int(x), self.mod - 2, mod=self.mod) * pow(2, 2 * self.n, self.mod) % self.mod + + return {'x': x, 'garbage1': garbage1, 'garbage2': garbage2} + def wire_symbol( self, reg: Optional['Register'], idx: Tuple[int, ...] 
= tuple() ) -> 'WireSymbol': @@ -112,4 +149,8 @@ def wire_symbol( return TextBox('x') elif reg.name == 'out': return TextBox('$x^{-1}$') + elif reg.name == 'garbage1': + return TextBox('garbage1') + elif reg.name == 'garbage2': + return TextBox('garbage2') raise ValueError(f'Unrecognized register name {reg.name}') diff --git a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py index 95f74edc7..9f289add1 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py +++ b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py @@ -650,7 +650,8 @@ def on_classical_vals( raise ValueError(f'classical action is not supported for {self}') if self.uncompute: assert ( - target is not None and target == (x * y * pow(2, self.bitsize, self.mod)) % self.mod + target is not None + and target == (x * y * pow(2, self.bitsize * (self.mod - 2), self.mod)) % self.mod ) assert qrom_indices is not None assert reduced is not None diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index bc3be1a65..347c78b61 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -98,6 +98,7 @@ import qualtran.bloqs.data_loading.qrom import qualtran.bloqs.data_loading.select_swap_qrom import qualtran.bloqs.factoring._factoring_shims +import qualtran.bloqs.factoring.ecc.ec_add import qualtran.bloqs.factoring.rsa import qualtran.bloqs.for_testing.atom import qualtran.bloqs.for_testing.casting @@ -348,6 +349,13 @@ "qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul, "qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul, "qualtran.bloqs.factoring._factoring_shims.MeasureQFT": qualtran.bloqs.factoring._factoring_shims.MeasureQFT, + 
"qualtran.bloqs.factoring.ecc.ec_add._ECAddStepOne": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepOne, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepTwo": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepTwo, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepThree": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepThree, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFour": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFour, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFive": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFive, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepSix": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepSix, + "qualtran.bloqs.factoring.ecc.ec_add.ECAdd": qualtran.bloqs.factoring.ecc.ec_add.ECAdd, "qualtran.bloqs.factoring.rsa.rsa_phase_estimate.RSAPhaseEstimate": qualtran.bloqs.factoring.rsa.rsa_phase_estimate.RSAPhaseEstimate, "qualtran.bloqs.factoring.rsa.rsa_mod_exp.ModExp": qualtran.bloqs.factoring.rsa.rsa_mod_exp.ModExp, "qualtran.bloqs.for_testing.atom.TestAtom": qualtran.bloqs.for_testing.atom.TestAtom,