Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in ECAdd bloq #1489

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
148 changes: 107 additions & 41 deletions qualtran/bloqs/factoring/ecc/ec_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
bloq_example,
BloqBuilder,
BloqDocSpec,
CtrlSpec,
DecomposeTypeError,
QBit,
QMontgomeryUInt,
Expand Down Expand Up @@ -255,10 +256,6 @@ def on_classical_vals(
QMontgomeryUInt(self.n).montgomery_inverse(int(x), int(self.mod)),
int(self.mod),
)
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit
# which flips f1 when lam and lam_r are equal.
if lam == lam_r:
f1 = (f1 + 1) % 2
else:
lam = 0
return {'f1': f1, 'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam, 'lam_r': lam_r}
Expand Down Expand Up @@ -298,6 +295,12 @@ def build_composite_bloq(
y=y,
)

# Allocate an ancilla qubit that acts as a flag for the rare condition that the
# pre-computed lambda_r is equal to the calculated lambda. This ancilla is used to properly
# clear the f1 qubit when lambda is set to lambda_r.
ancilla = bb.allocate()
z4, lam_r, ancilla = bb.add(Equals(QMontgomeryUInt(self.n)), x=z4, y=lam_r, target=ancilla)

# If ctrl = 1 and x != a: lam = (y - b) / (x - a) % p.
z4_split = bb.split(z4)
lam_split = bb.split(lam)
Expand Down Expand Up @@ -325,7 +328,18 @@ def build_composite_bloq(
lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n))

# If lam = lam_r: return f1 = 0. (If not we will flip f1 to 0 at the end iff x_r = y_r = 0).
lam, lam_r, f1 = bb.add(Equals(QMontgomeryUInt(self.n)), x=lam, y=lam_r, target=f1)
# Only flip when lam is set to lam_r.
ancilla, lam, lam_r, f1 = bb.add(
Equals(QMontgomeryUInt(self.n)).controlled(ctrl_spec=CtrlSpec(cvs=0)),
ctrl=ancilla,
x=lam,
y=lam_r,
target=f1,
)

# Clear the ancilla bit and free it.
z4, lam_r, ancilla = bb.add(Equals(QMontgomeryUInt(self.n)), x=z4, y=lam_r, target=ancilla)
bb.free(ancilla)

# Uncompute the modular multiplication then the modular inversion.
x, y = bb.add(
Expand All @@ -345,7 +359,8 @@ def build_composite_bloq(

def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
return {
Equals(QMontgomeryUInt(self.n)): 1,
Equals(QMontgomeryUInt(self.n)): 2,
Equals(QMontgomeryUInt(self.n)).controlled(ctrl_spec=CtrlSpec(cvs=0)): 1,
ModSub(QMontgomeryUInt(self.n), mod=self.mod): 1,
CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1,
KaliskiModInverse(bitsize=self.n, mod=self.mod): 1,
Expand Down Expand Up @@ -654,6 +669,7 @@ class _ECAddStepFive(Bloq):
will contain the x component of the resultant curve point.
y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which
will contain the y component of the resultant curve point.
lam_r: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form.
lam: The lambda slope used in the addition operation.

References:
Expand All @@ -674,6 +690,7 @@ def signature(self) -> 'Signature':
Register('b', QMontgomeryUInt(self.n)),
Register('x', QMontgomeryUInt(self.n)),
Register('y', QMontgomeryUInt(self.n)),
Register('lam_r', QMontgomeryUInt(self.n)),
Register('lam', QMontgomeryUInt(self.n), side=Side.LEFT),
]
)
Expand All @@ -685,14 +702,15 @@ def on_classical_vals(
b: 'ClassicalValT',
x: 'ClassicalValT',
y: 'ClassicalValT',
lam_r: 'ClassicalValT',
lam: 'ClassicalValT',
) -> Dict[str, 'ClassicalValT']:
if ctrl == 1:
x = (a - x) % self.mod
y = (y - b) % self.mod
else:
x = (x + a) % self.mod
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y}
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r}

def build_composite_bloq(
self,
Expand All @@ -702,6 +720,7 @@ def build_composite_bloq(
b: Soquet,
x: Soquet,
y: Soquet,
lam_r: Soquet,
lam: Soquet,
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.n):
Expand Down Expand Up @@ -731,9 +750,31 @@ def build_composite_bloq(
z4_split[i] = ctrls[1]
z4 = bb.join(z4_split, dtype=QMontgomeryUInt(self.n))
lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n))
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit where lambda
# is not set to 0 before being freed.
bb.add(Free(QMontgomeryUInt(self.n), dirty=True), reg=lam)

# If the denominator of lambda is 0, lam = lam_r so we clear lam with lam_r.
ancilla = bb.allocate()
x_split = bb.split(x)
x_split, ancilla = bb.add(
MultiControlX(cvs=[0] * int(self.n)), controls=x_split, target=ancilla
)
lam_r_split = bb.split(lam_r)
lam_split = bb.split(lam)
for i in range(int(self.n)):
ctrls = [ctrl, ancilla, lam_r_split[i]]
ctrls, lam_split[i] = bb.add(
MultiControlX(cvs=[1, 1, 1]), controls=ctrls, target=lam_split[i]
)
ctrl = ctrls[0]
ancilla = ctrls[1]
lam_r_split[i] = ctrls[2]
lam_r = bb.join(lam_r_split, dtype=QMontgomeryUInt(self.n))
lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n))
x_split, ancilla = bb.add(
MultiControlX(cvs=[0] * int(self.n)), controls=x_split, target=ancilla
)
x = bb.join(x_split, dtype=QMontgomeryUInt(self.n))
bb.free(ancilla)
bb.add(Free(QMontgomeryUInt(self.n)), reg=lam)

# Uncompute multiplication and inverse.
x, y = bb.add(
Expand All @@ -758,9 +799,14 @@ def build_composite_bloq(
ctrl, b, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=b, y=y)

# Return the output registers.
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y}
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r}

def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
cvs: Union[list[int], HasLength]
if isinstance(self.n, int):
cvs = [0] * self.n
else:
cvs = HasLength(self.n)
return {
CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1,
KaliskiModInverse(bitsize=self.n, mod=self.mod): 1,
Expand All @@ -773,6 +819,8 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
KaliskiModInverse(bitsize=self.n, mod=self.mod).adjoint(): 1,
ModAdd(self.n, mod=self.mod): 1,
MultiControlX(cvs=[1, 1]): self.n,
MultiControlX(cvs=cvs): 2,
MultiControlX(cvs=[1, 1, 1]): self.n,
CModNeg(QMontgomeryUInt(self.n), mod=self.mod): 1,
}

Expand Down Expand Up @@ -865,6 +913,21 @@ def build_composite_bloq(
f3 = f_ctrls[1]
f4 = f_ctrls[2]

# Unset f2 if ((a, b) = (0, 0) AND y = 0) OR ((x, y) = (0, 0) AND b = 0).
aby_arr = np.concatenate([bb.split(a), bb.split(b), bb.split(y)])
aby_arr, f2 = bb.add(MultiControlX(cvs=[0] * 3 * self.n), controls=aby_arr, target=f2)
aby_arr = np.split(aby_arr, 3)
a = bb.join(aby_arr[0], dtype=QMontgomeryUInt(self.n))
b = bb.join(aby_arr[1], dtype=QMontgomeryUInt(self.n))
y = bb.join(aby_arr[2], dtype=QMontgomeryUInt(self.n))

xyb_arr = np.concatenate([bb.split(x), bb.split(y), bb.split(b)])
xyb_arr, f2 = bb.add(MultiControlX(cvs=[0] * 3 * self.n), controls=xyb_arr, target=f2)
xyb_arr = np.split(xyb_arr, 3)
x = bb.join(xyb_arr[0], dtype=QMontgomeryUInt(self.n))
y = bb.join(xyb_arr[1], dtype=QMontgomeryUInt(self.n))
b = bb.join(xyb_arr[2], dtype=QMontgomeryUInt(self.n))
Comment on lines +917 to +929
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can replace this with the default controlled to avoid manual splits and joins:

mcx = XGate().controlled(CtrlSpec(qdtypes=QMontgomeryUInt(self.n), cvs=[0, 0, 0]))
[a, b, y], f2 = bb.add(mcx, ctrl=[a, b, y], q=f2)
[x, y, b], f2 = bb.add(mcx, ctrl=[x, y, b], q=f2)

Though I suspect the types may not be propagated correctly yet. In case you try the above suggestion and it fails, could you please open an issue?

p.s. this would also enable decomposing for symbolic self.n which would be an added benefit.


# Set (x, y) to (a, b) if f4 is set.
a_split = bb.split(a)
x_split = bb.split(x)
Expand All @@ -885,24 +948,6 @@ def build_composite_bloq(
b = bb.join(b_split, QMontgomeryUInt(self.n))
y = bb.join(y_split, QMontgomeryUInt(self.n))

# Unset f4 if (x, y) = (a, b).
ab = bb.join(np.concatenate([bb.split(a), bb.split(b)]), dtype=QMontgomeryUInt(2 * self.n))
xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n))
ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4)
ab_split = bb.split(ab)
a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))
xy_split = bb.split(xy)
x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))

# Unset f3 if (a, b) = (0, 0).
ab_arr = np.concatenate([bb.split(a), bb.split(b)])
ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3)
ab_arr = np.split(ab_arr, 2)
a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n))

# If f1 and f2 are set, subtract a from x and add b to y.
ancilla = bb.add(ZeroState())
toff_ctrl = [f1, f2]
Expand All @@ -925,6 +970,24 @@ def build_composite_bloq(
f2 = toff_ctrl[1]
bb.add(Free(QBit()), reg=ancilla)

# Unset f4 if (x, y) = (a, b).
ab = bb.join(np.concatenate([bb.split(a), bb.split(b)]), dtype=QMontgomeryUInt(2 * self.n))
xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n))
ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4)
ab_split = bb.split(ab)
a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))
xy_split = bb.split(xy)
x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))

# Unset f3 if (a, b) = (0, 0).
ab_arr = np.concatenate([bb.split(a), bb.split(b)])
ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3)
ab_arr = np.split(ab_arr, 2)
a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n))

# Unset f1 and f2 if (x, y) = (0, 0).
xy_arr = np.concatenate([bb.split(x), bb.split(y)])
xy_arr, junk, out = bb.add(MultiAnd(cvs=[0] * 2 * self.n), ctrl=xy_arr)
Expand All @@ -941,33 +1004,35 @@ def build_composite_bloq(
y = bb.join(xy_arr[1], dtype=QMontgomeryUInt(self.n))

# Free all ancilla qubits in the zero state.
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bugs in circuit where f1,
# f2, and f4 are freed before being set to 0.
bb.add(Free(QBit(), dirty=True), reg=f1)
bb.add(Free(QBit(), dirty=True), reg=f2)
bb.add(Free(QBit()), reg=f1)
bb.add(Free(QBit()), reg=f2)
bb.add(Free(QBit()), reg=f3)
bb.add(Free(QBit(), dirty=True), reg=f4)
bb.add(Free(QBit()), reg=f4)
bb.add(Free(QBit()), reg=ctrl)

# Return the output registers.
return {'a': a, 'b': b, 'x': x, 'y': y}

def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
cvs: Union[list[int], HasLength]
cvs2: Union[list[int], HasLength]
cvs3: Union[list[int], HasLength]
if isinstance(self.n, int):
cvs = [0] * 2 * self.n
cvs2 = [0] * 2 * self.n
cvs3 = [0] * 3 * self.n
else:
cvs = HasLength(2 * self.n)
cvs2 = HasLength(2 * self.n)
cvs3 = HasLength(3 * self.n)
return {
MultiControlX(cvs=cvs): 1,
MultiControlX(cvs=cvs2): 1,
MultiControlX(cvs=cvs3): 2,
MultiControlX(cvs=[0] * 3): 1,
CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1,
CModAdd(QMontgomeryUInt(self.n), mod=self.mod): 1,
Toffoli(): 2 * self.n + 4,
Equals(QMontgomeryUInt(2 * self.n)): 1,
MultiAnd(cvs=cvs): 1,
MultiAnd(cvs=cvs2): 1,
MultiTargetCNOT(2): 1,
MultiAnd(cvs=cvs).adjoint(): 1,
MultiAnd(cvs=cvs2).adjoint(): 1,
}


Expand Down Expand Up @@ -1046,13 +1111,14 @@ def build_composite_bloq(
x, y, lam = bb.add(
_ECAddStepFour(n=self.n, mod=self.mod, window_size=self.window_size), x=x, y=y, lam=lam
)
ctrl, a, b, x, y = bb.add(
ctrl, a, b, x, y, lam_r = bb.add(
_ECAddStepFive(n=self.n, mod=self.mod, window_size=self.window_size),
ctrl=ctrl,
a=a,
b=b,
x=x,
y=y,
lam_r=lam_r,
lam=lam,
)
a, b, x, y = bb.add(
Expand Down
14 changes: 10 additions & 4 deletions qualtran/bloqs/factoring/ecc/ec_add_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def test_ec_add_steps_classical_fast(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
ret2 = bloq.decompose_bloq().call_classically(
Expand All @@ -118,6 +119,7 @@ def test_ec_add_steps_classical_fast(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
assert ret1 == ret2
Expand All @@ -128,6 +130,7 @@ def test_ec_add_steps_classical_fast(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
bloq = _ECAddStepSix(n=n, mod=p)
Expand Down Expand Up @@ -250,6 +253,7 @@ def test_ec_add_steps_classical(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
ret2 = bloq.decompose_bloq().call_classically(
Expand All @@ -258,6 +262,7 @@ def test_ec_add_steps_classical(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
assert ret1 == ret2
Expand All @@ -268,6 +273,7 @@ def test_ec_add_steps_classical(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
bloq = _ECAddStepSix(n=n, mod=p)
Expand Down Expand Up @@ -413,12 +419,12 @@ def test_ec_add_symbolic_cost():

# Litinski 2023 https://arxiv.org/abs/2306.08585
# Based on the counts from Figures 3, 5, and 8 the toffoli count for ECAdd is 126.5n^2 + 189n.
# The following formula is 126.5n^2 + 195.5n - 31. We account for the discrepancy in the
# The following formula is 126.5n^2 + 217.5n - 36. We account for the discrepancy in the
# coefficient of n by a reduction in the toffoli cost of Montgomery ModMult, an increase in the
# toffoli cost for Kaliski Mod Inverse, n extra toffolis in ModNeg, 2n extra toffolis to do n
# 3-controlled toffolis in step 2. The expression is written with rationals because sympy
# comparison fails with floats.
assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(407, 2) * n - 31
# 3-controlled toffolis in step 2, and a few extra gates added to fix bugs found in the
# circuit. The expression is written with rationals because sympy comparison fails with floats.
assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(435, 2) * n - 36


def test_ec_add(bloq_autotester):
Expand Down