Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Decomposition of ECWindowAddR #1477

Merged
merged 13 commits into from
Nov 4, 2024
146 changes: 130 additions & 16 deletions qualtran/bloqs/factoring/ecc/ec_add_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,30 @@
from functools import cached_property
from typing import Dict, Optional, Tuple, Union

import numpy as np
import sympy
from attrs import frozen

from qualtran import Bloq, bloq_example, BloqDocSpec, QBit, QUInt, Register, Signature
from qualtran import (
Bloq,
bloq_example,
BloqBuilder,
BloqDocSpec,
QBit,
QMontgomeryUInt,
QUInt,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran.bloqs.data_loading import QROAMClean
from qualtran.drawing import Circle, Text, TextBox, WireSymbol
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import is_symbolic, Shaped

from .ec_add import ECAdd
from .ec_point import ECPoint


Expand Down Expand Up @@ -113,33 +130,133 @@ class ECWindowAddR(Bloq):

Args:
n: The bitsize of the two registers storing the elliptic curve point
window_size: The number of bits in the window.
R: The elliptic curve point to add.
R: The elliptic curve point to add (NOT in montgomery form).
add_window_size: The number of bits in the ECAdd window.
mul_window_size: The number of bits in the modular multiplication window.

Registers:
ctrl: `window_size` control bits.
x: The x component of the input elliptic curve point of bitsize `n`.
y: The y component of the input elliptic curve point of bitsize `n`.
x: The x component of the input elliptic curve point of bitsize `n` in montgomery form.
y: The y component of the input elliptic curve point of bitsize `n` in montgomery form.

References:
[How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585).
Litinski. 2013. Section 1, eq. (3) and (4).
"""

n: int
window_size: int
R: ECPoint
add_window_size: int
mul_window_size: int = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a window size of 1 means normal multiplication, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct


@cached_property
def signature(self) -> 'Signature':
return Signature(
[
Register('ctrl', QBit(), shape=(self.window_size,)),
Register('ctrl', QBit(), shape=(self.add_window_size,)),
Register('x', QUInt(self.n)),
Register('y', QUInt(self.n)),
]
)

@cached_property
def qrom(self) -> QROAMClean:
if is_symbolic(self.n) or is_symbolic(self.add_window_size):
log_block_sizes = None
if is_symbolic(self.n) and not is_symbolic(self.add_window_size):
# We assume that bitsize is much larger than window_size
log_block_sizes = (0,)
return QROAMClean(
[
Shaped((2**self.add_window_size,)),
Shaped((2**self.add_window_size,)),
Shaped((2**self.add_window_size,)),
],
selection_bitsizes=(self.add_window_size,),
target_bitsizes=(self.n, self.n, self.n),
log_block_sizes=log_block_sizes,
)

cR = self.R
data_a, data_b, data_lam = [0], [0], [0]
for _ in range(1, 2**self.add_window_size):
data_a.append(QMontgomeryUInt(self.n).uint_to_montgomery(int(cR.x), int(self.R.mod)))
fpapa250 marked this conversation as resolved.
Show resolved Hide resolved
data_b.append(QMontgomeryUInt(self.n).uint_to_montgomery(int(cR.y), int(self.R.mod)))
lam_num = (3 * cR.x**2 + cR.curve_a) % cR.mod
lam_denom = (2 * cR.y) % cR.mod
if lam_denom != 0:
lam = (lam_num * pow(lam_denom, -1, mod=cR.mod)) % cR.mod
else:
lam = 0
data_lam.append(QMontgomeryUInt(self.n).uint_to_montgomery(int(lam), int(self.R.mod)))
cR = cR + self.R

return QROAMClean(
[data_a, data_b, data_lam],
selection_bitsizes=(self.add_window_size,),
target_bitsizes=(self.n, self.n, self.n),
)

def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl: 'SoquetT', x: 'Soquet', y: 'Soquet'
) -> Dict[str, 'SoquetT']:
ctrl = bb.join(np.array(ctrl))

ctrl, a, b, lam_r, *junk = bb.add(self.qrom, selection=ctrl)

a, b, x, y, lam_r = bb.add(
# TODO(https://github.com/quantumlib/Qualtran/issues/1476): make ECAdd accept SymbolicInt.
ECAdd(n=self.n, mod=int(self.R.mod), window_size=self.mul_window_size),
a=a,
b=b,
x=x,
y=y,
lam_r=lam_r,
)

if junk:
assert len(junk) == 3
ctrl = bb.add(
self.qrom.adjoint(),
selection=ctrl,
target0_=a,
target1_=b,
target2_=lam_r,
junk_target0_=junk[0],
junk_target1_=junk[1],
junk_target2_=junk[2],
)
else:
ctrl = bb.add(
self.qrom.adjoint(), selection=ctrl, target0_=a, target1_=b, target2_=lam_r
)

return {'ctrl': bb.split(ctrl), 'x': x, 'y': y}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {
self.qrom: 1,
# TODO(https://github.com/quantumlib/Qualtran/issues/1476): make ECAdd accept SymbolicInt.
ECAdd(self.n, int(self.R.mod), self.mul_window_size): 1,
self.qrom.adjoint(): 1,
}

def on_classical_vals(self, ctrl, x, y) -> Dict[str, Union['ClassicalValT', sympy.Expr]]:
# TODO(https://github.com/quantumlib/Qualtran/issues/1476): make ECAdd accept SymbolicInt.
A = ECPoint(
QMontgomeryUInt(self.n).montgomery_to_uint(int(x), int(self.R.mod)),
QMontgomeryUInt(self.n).montgomery_to_uint(int(y), int(self.R.mod)),
mod=self.R.mod,
curve_a=self.R.curve_a,
)
ctrls = QUInt(self.n).from_bits(ctrl)
result: ECPoint = A + (ctrls * self.R)
fpapa250 marked this conversation as resolved.
Show resolved Hide resolved
return {
'ctrl': ctrl,
'x': QMontgomeryUInt(self.n).uint_to_montgomery(int(result.x), int(self.R.mod)),
'y': QMontgomeryUInt(self.n).uint_to_montgomery(int(result.y), int(self.R.mod)),
}

def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
Expand All @@ -153,16 +270,13 @@ def wire_symbol(
return TextBox(f'$+{self.R.y}$')
raise ValueError(f'Unrecognized register name {reg.name}')

def __str__(self):
return f'ECWindowAddR({self.n=})'


@bloq_example
def _ec_window_add() -> ECWindowAddR:
n, p = sympy.symbols('n p')
Rx, Ry = sympy.symbols('Rx Ry')
ec_window_add = ECWindowAddR(n=n, window_size=3, R=ECPoint(Rx, Ry, mod=p))
return ec_window_add
def _ec_window_add_r_small() -> ECWindowAddR:
n = 16
P = ECPoint(2, 2, mod=7, curve_a=3)
ec_window_add_r_small = ECWindowAddR(n=n, R=P, add_window_size=4)
return ec_window_add_r_small


_EC_WINDOW_ADD_BLOQ_DOC = BloqDocSpec(bloq_cls=ECWindowAddR, examples=[_ec_window_add])
_EC_WINDOW_ADD_BLOQ_DOC = BloqDocSpec(bloq_cls=ECWindowAddR, examples=[_ec_window_add_r_small])
77 changes: 70 additions & 7 deletions qualtran/bloqs/factoring/ecc/ec_add_r_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,79 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from qualtran.bloqs.factoring.ecc.ec_add_r import _ec_add_r, _ec_add_r_small, _ec_window_add
import numpy as np
import pytest

import qualtran.testing as qlt_testing
from qualtran import QMontgomeryUInt, QUInt
from qualtran.bloqs.factoring.ecc.ec_add_r import (
_ec_add_r,
_ec_add_r_small,
_ec_window_add_r_small,
ECWindowAddR,
)
from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join

def test_ec_add_r(bloq_autotester):
bloq_autotester(_ec_add_r)
from .ec_add_r import ECWindowAddR
from .ec_point import ECPoint


def test_ec_add_r_small(bloq_autotester):
bloq_autotester(_ec_add_r_small)
@pytest.mark.parametrize('bloq', [_ec_add_r, _ec_add_r_small, _ec_window_add_r_small])
def test_ec_add_r(bloq_autotester, bloq):
bloq_autotester(bloq)


def test_ec_window_add(bloq_autotester):
bloq_autotester(_ec_window_add)
@pytest.mark.parametrize('a,b', [(15, 13), (0, 0)])
@pytest.mark.parametrize(
['n', 'window_size'],
[
(n, window_size)
for n in range(5, 8)
for window_size in range(1, n + 1)
if n % window_size == 0
],
)
def test_ec_window_add_r_bloq_counts(n, window_size, a, b):
p = 17
R = ECPoint(a, b, mod=p)
bloq = ECWindowAddR(n=n, R=R, add_window_size=window_size)
qlt_testing.assert_equivalent_bloq_counts(bloq, [ignore_alloc_free, ignore_split_join])


@pytest.mark.parametrize(
['n', 'm'], [(n, m) for n in range(4, 5) for m in range(1, n + 1) if n % m == 0]
)
@pytest.mark.parametrize('a,b', [(15, 13), (0, 0)])
@pytest.mark.parametrize('x,y', [(15, 13), (5, 8)])
@pytest.mark.parametrize('ctrl', [0, 1, 5])
def test_ec_window_add_r_classical(n, m, ctrl, x, y, a, b):
p = 17
R = ECPoint(a, b, mod=p)
x = QMontgomeryUInt(n).uint_to_montgomery(x, p)
y = QMontgomeryUInt(n).uint_to_montgomery(y, p)
ctrl = np.array(QUInt(m).to_bits(ctrl % (2**m)))
bloq = ECWindowAddR(n=n, R=R, add_window_size=m, mul_window_size=m)
ret1 = bloq.call_classically(ctrl=ctrl, x=x, y=y)
ret2 = bloq.decompose_bloq().call_classically(ctrl=ctrl, x=x, y=y)
for i, ret1_i in enumerate(ret1):
np.testing.assert_array_equal(ret1_i, ret2[i])


@pytest.mark.slow
@pytest.mark.parametrize(
['n', 'm'], [(n, m) for n in range(7, 9) for m in range(1, n + 1) if n % m == 0]
)
@pytest.mark.parametrize('a,b', [(15, 13), (0, 0)])
@pytest.mark.parametrize('x,y', [(15, 13), (5, 8)])
@pytest.mark.parametrize('ctrl', [0, 1, 5, 8])
def test_ec_window_add_r_classical_slow(n, m, ctrl, x, y, a, b):
p = 17
R = ECPoint(a, b, mod=p)
x = QMontgomeryUInt(n).uint_to_montgomery(x, p)
y = QMontgomeryUInt(n).uint_to_montgomery(y, p)
ctrl = np.array(QUInt(m).to_bits(ctrl % (2**m)))
bloq = ECWindowAddR(n=n, R=R, add_window_size=m, mul_window_size=m)
ret1 = bloq.call_classically(ctrl=ctrl, x=x, y=y)
ret2 = bloq.decompose_bloq().call_classically(ctrl=ctrl, x=x, y=y)
for i, ret1_i in enumerate(ret1):
np.testing.assert_array_equal(ret1_i, ret2[i])
51 changes: 44 additions & 7 deletions qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from functools import cached_property
from typing import Dict
from typing import Dict, Union

import numpy as np
import sympy
from attrs import frozen

Expand All @@ -34,7 +36,7 @@
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator

from .._factoring_shims import MeasureQFT
from .ec_add_r import ECAddR
from .ec_add_r import ECAddR, ECWindowAddR
from .ec_point import ECPoint


Expand All @@ -45,38 +47,73 @@ class ECPhaseEstimateR(Bloq):
This is used as a subroutine in `FindECCPrivateKey`. First, we phase-estimate the
addition of the base point $P$, then of the public key $Q$.

When the ellptic curve point addition window size is 1 we use the ECAddR bloq which has it's
own bespoke circuit; when it is greater than 1 we use the windowed circuit which uses
pre-computed classical additions loaded into the circuit.

Args:
n: The bitsize of the elliptic curve points' x and y registers.
point: The elliptic curve point to phase estimate against.
add_window_size: The number of bits in the ECAdd window.
mul_window_size: The number of bits in the modular multiplication window.
"""

n: int
point: ECPoint
add_window_size: int = 1
mul_window_size: int = 1

@cached_property
def signature(self) -> 'Signature':
return Signature([Register('x', QUInt(self.n)), Register('y', QUInt(self.n))])

@property
def ec_add(self) -> Union[functools.partial[ECAddR], functools.partial[ECWindowAddR]]:
if self.add_window_size == 1:
return functools.partial(ECAddR, n=self.n)
return functools.partial(
ECWindowAddR,
n=self.n,
add_window_size=self.add_window_size,
mul_window_size=self.mul_window_size,
)

@property
def num_windows(self) -> int:
return self.n // self.add_window_size

def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']:
if isinstance(self.n, sympy.Expr):
raise DecomposeTypeError("Cannot decompose symbolic `n`.")
ctrl = [bb.add(PlusState()) for _ in range(self.n)]
for i in range(self.n):
ctrl[i], x, y = bb.add(ECAddR(n=self.n, R=2**i * self.point), ctrl=ctrl[i], x=x, y=y)

if self.add_window_size == 1:
for i in range(self.n):
ctrl[i], x, y = bb.add(self.ec_add(R=2**i * self.point), ctrl=ctrl[i], x=x, y=y)
else:
ctrls = np.split(np.array(ctrl), self.num_windows)
for i in range(self.num_windows):
ctrls[i], x, y = bb.add(
self.ec_add(R=2 ** (self.add_window_size * i) * self.point),
ctrl=ctrls[i],
x=x,
y=y,
)
ctrl = np.concatenate(ctrls, axis=None)

bb.add(MeasureQFT(n=self.n), x=ctrl)
return {'x': x, 'y': y}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {ECAddR(n=self.n, R=self.point): self.n, MeasureQFT(n=self.n): 1}
return {self.ec_add(R=self.point): self.num_windows, MeasureQFT(n=self.n): 1}

def __str__(self) -> str:
return f'PE${self.point}$'


@bloq_example
def _ec_pe() -> ECPhaseEstimateR:
n, p = sympy.symbols('n p ')
n, p = sympy.symbols('n p')
Rx, Ry = sympy.symbols('R_x R_y')
ec_pe = ECPhaseEstimateR(n=n, point=ECPoint(Rx, Ry, mod=p))
return ec_pe
Expand All @@ -90,4 +127,4 @@ def _ec_pe_small() -> ECPhaseEstimateR:
return ec_pe_small


_EC_PE_BLOQ_DOC = BloqDocSpec(bloq_cls=ECPhaseEstimateR, examples=[_ec_pe])
_EC_PE_BLOQ_DOC = BloqDocSpec(bloq_cls=ECPhaseEstimateR, examples=[_ec_pe, _ec_pe_small])
Loading
Loading