Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add secret check for blinded-commitment #126

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
10 changes: 10 additions & 0 deletions blind_and_swap/blind_and_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
hash_to_int = lambda x: int.from_bytes(sha256(x).digest(), 'little')

from py_ecc import bn128 as curve
from poly_utils import PrimeField

POINT = tuple
SIG2 = tuple
SIG3 = tuple
Expand Down Expand Up @@ -125,18 +127,26 @@ def test():
assert verify_1of2(b'cow', KEY1, KEY2, secondof2_sig, BASE)
print("Passed 1 of 2 signature test")
# Blind and swap proofs
# Create two secrets of commitments
f = PrimeField(curve.curve_order)
x1 = f.mul(69042, f.inv(31337))
x2 = f.mul(299792458, f.inv(8675309))
A1, B1, A2, B2 = (curve.multiply(curve.G1, x) for x in (31337, 69042, 8675309, 299792458))
factor = 5
C1, D1, C2, D2, proof = prove_blind_and_swap(A1, B1, A2, B2, factor, False)
FAKE_POINT = curve.multiply(curve.G1, 98765432123456789)
assert (C1, D1, C2, D2) == tuple(curve.multiply(P, factor) for P in (A1, B1, A2, B2))
assert verify_blind_and_swap(A1, B1, A2, B2, C1, D1, C2, D2, proof)
assert not verify_blind_and_swap(A1, B1, A2, B2, C1, FAKE_POINT, C2, D2, proof)
assert curve.multiply(C1, x1) == D1
assert curve.multiply(C2, x2) == D2
factor2 = 7
E1, F1, E2, F2, proof = prove_blind_and_swap(C1, D1, C2, D2, factor2, True)
assert (E1, F1, E2, F2) == tuple(curve.multiply(P, factor2) for P in (C2, D2, C1, D1))
assert verify_blind_and_swap(C1, D1, C2, D2, E1, F1, E2, F2, proof)
assert not verify_blind_and_swap(C1, D1, C2, D2, E1, F1, E2, FAKE_POINT, proof)
assert curve.multiply(E2, x1) == F2
assert curve.multiply(E1, x2) == F1
print("Passed blind-and-swap test")

if __name__ == '__main__':
Expand Down
207 changes: 207 additions & 0 deletions blind_and_swap/poly_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Creates an object that includes convenience operations for numbers
# and polynomials in some prime field
class PrimeField():
def __init__(self, modulus):
assert pow(2, modulus, modulus) == 2
self.modulus = modulus

def add(self, x, y):
return (x+y) % self.modulus

def sub(self, x, y):
return (x-y) % self.modulus

def mul(self, x, y):
return (x*y) % self.modulus

def exp(self, x, p):
return pow(x, p, self.modulus)

# Modular inverse using the extended Euclidean algorithm
def inv(self, a):
if a == 0:
return 0
lm, hm = 1, 0
low, high = a % self.modulus, self.modulus
while low > 1:
r = high//low
nm, new = hm-lm*r, high-low*r
lm, low, hm, high = nm, new, lm, low
return lm % self.modulus

def multi_inv(self, values):
partials = [1]
for i in range(len(values)):
partials.append(self.mul(partials[-1], values[i] or 1))
inv = self.inv(partials[-1])
outputs = [0] * len(values)
for i in range(len(values), 0, -1):
outputs[i-1] = self.mul(partials[i-1], inv) if values[i-1] else 0
inv = self.mul(inv, values[i-1] or 1)
return outputs

def div(self, x, y):
return self.mul(x, self.inv(y))

# Evaluate a polynomial at a point
def eval_poly_at(self, p, x):
y = 0
power_of_x = 1
for i, p_coeff in enumerate(p):
y += power_of_x * p_coeff
power_of_x = (power_of_x * x) % self.modulus
return y % self.modulus

# Arithmetic for polynomials
def add_polys(self, a, b):
return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0))
% self.modulus for i in range(max(len(a), len(b)))]

def sub_polys(self, a, b):
return [((a[i] if i < len(a) else 0) - (b[i] if i < len(b) else 0))
% self.modulus for i in range(max(len(a), len(b)))]

def mul_by_const(self, a, c):
return [(x*c) % self.modulus for x in a]

def mul_polys(self, a, b):
o = [0] * (len(a) + len(b) - 1)
for i, aval in enumerate(a):
for j, bval in enumerate(b):
o[i+j] += a[i] * b[j]
return [x % self.modulus for x in o]

def div_polys(self, a, b):
assert len(a) >= len(b)
a = [x for x in a]
o = []
apos = len(a) - 1
bpos = len(b) - 1
diff = apos - bpos
while diff >= 0:
quot = self.div(a[apos], b[bpos])
o.insert(0, quot)
for i in range(bpos, -1, -1):
a[diff+i] -= b[i] * quot
apos -= 1
diff -= 1
return [x % self.modulus for x in o]

def mod_polys(self, a, b):
return self.sub_polys(a, self.mul_polys(b, self.div_polys(a, b)))[:len(b)-1]

# Build a polynomial from a few coefficients
def sparse(self, coeff_dict):
o = [0] * (max(coeff_dict.keys()) + 1)
for k, v in coeff_dict.items():
o[k] = v % self.modulus
return o

# Build a polynomial that returns 0 at all specified xs
def zpoly(self, xs):
root = [1]
for x in xs:
root.insert(0, 0)
for j in range(len(root)-1):
root[j] -= root[j+1] * x
return [x % self.modulus for x in root]

# Given p+1 y values and x values with no errors, recovers the original
# p+1 degree polynomial.
# Lagrange interpolation works roughly in the following way.
# 1. Suppose you have a set of points, eg. x = [1, 2, 3], y = [2, 5, 10]
# 2. For each x, generate a polynomial which equals its corresponding
# y coordinate at that point and 0 at all other points provided.
# 3. Add these polynomials together.

def lagrange_interp(self, xs, ys):
# Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn)
root = self.zpoly(xs)
assert len(root) == len(ys) + 1
# print(root)
# Generate per-value numerator polynomials, eg. for x=x2,
# (x - x1) * (x - x3) * ... * (x - xn), by dividing the master
# polynomial back by each x coordinate
nums = [self.div_polys(root, [-x, 1]) for x in xs]
# Generate denominators by evaluating numerator polys at each x
denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))]
invdenoms = self.multi_inv(denoms)
# Generate output polynomial, which is the sum of the per-value numerator
# polynomials rescaled to have the right y values
b = [0 for y in ys]
for i in range(len(xs)):
yslice = self.mul(ys[i], invdenoms[i])
for j in range(len(ys)):
if nums[i][j] and ys[i]:
b[j] += nums[i][j] * yslice
return [x % self.modulus for x in b]

# Optimized poly evaluation for degree 4
def eval_quartic(self, p, x):
xsq = x * x % self.modulus
xcb = xsq * x
return (p[0] + p[1] * x + p[2] * xsq + p[3] * xcb) % self.modulus

# Optimized version of the above restricted to deg-4 polynomials
def lagrange_interp_4(self, xs, ys):
x01, x02, x03, x12, x13, x23 = \
xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
m = self.modulus
eq0 = [-x12 * xs[3] % m, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
eq1 = [-x02 * xs[3] % m, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
eq2 = [-x01 * xs[3] % m, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
eq3 = [-x01 * xs[2] % m, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
e0 = self.eval_poly_at(eq0, xs[0])
e1 = self.eval_poly_at(eq1, xs[1])
e2 = self.eval_poly_at(eq2, xs[2])
e3 = self.eval_poly_at(eq3, xs[3])
e01 = e0 * e1
e23 = e2 * e3
invall = self.inv(e01 * e23)
inv_y0 = ys[0] * invall * e1 * e23 % m
inv_y1 = ys[1] * invall * e0 * e23 % m
inv_y2 = ys[2] * invall * e01 * e3 % m
inv_y3 = ys[3] * invall * e01 * e2 % m
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % m for i in range(4)]

# Optimized version of the above restricted to deg-2 polynomials
def lagrange_interp_2(self, xs, ys):
m = self.modulus
eq0 = [-xs[1] % m, 1]
eq1 = [-xs[0] % m, 1]
e0 = self.eval_poly_at(eq0, xs[0])
e1 = self.eval_poly_at(eq1, xs[1])
invall = self.inv(e0 * e1)
inv_y0 = ys[0] * invall * e1
inv_y1 = ys[1] * invall * e0
return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)]

# Optimized version of the above restricted to deg-4 polynomials
def multi_interp_4(self, xsets, ysets):
data = []
invtargets = []
for xs, ys in zip(xsets, ysets):
x01, x02, x03, x12, x13, x23 = \
xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
m = self.modulus
eq0 = [-x12 * xs[3] % m, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
eq1 = [-x02 * xs[3] % m, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
eq2 = [-x01 * xs[3] % m, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
eq3 = [-x01 * xs[2] % m, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
e0 = self.eval_quartic(eq0, xs[0])
e1 = self.eval_quartic(eq1, xs[1])
e2 = self.eval_quartic(eq2, xs[2])
e3 = self.eval_quartic(eq3, xs[3])
data.append([ys, eq0, eq1, eq2, eq3])
invtargets.extend([e0, e1, e2, e3])
invalls = self.multi_inv(invtargets)
o = []
for (i, (ys, eq0, eq1, eq2, eq3)) in enumerate(data):
invallz = invalls[i*4:i*4+4]
inv_y0 = ys[0] * invallz[0] % m
inv_y1 = ys[1] * invallz[1] % m
inv_y2 = ys[2] * invallz[2] % m
inv_y3 = ys[3] * invallz[3] % m
o.append([(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % m for i in range(4)])
# assert o == [self.lagrange_interp_4(xs, ys) for xs, ys in zip(xsets, ysets)]
return o
64 changes: 44 additions & 20 deletions erasure_code/2d_recovery/recover.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
import sys
import copy
import seaborn as sns
import matplotlib.pyplot as plt
import random

def mkmatrix(rows, cols):
return [[0 for _ in range(cols)] for _ in range(rows)]

def mkmatrix(rows, cols, default_value=0):
return [[default_value for _ in range(cols)] for _ in range(rows)]

def print_form(mat):
return '\n'.join(''.join(str(val) for val in row) for row in mat)

def recover(matrix):
def recover(matrix, show_plot=True):
rows, cols = len(matrix), len(matrix[0])
matrix = copy.deepcopy(matrix)
print(print_form(matrix))
if show_plot:
sns.heatmap(matrix, vmin=0, vmax=1, linewidth=0.5, cbar=False)
plt.pause(3)
for _round in range(1, rows + cols + 1):
print(f"\nRound {_round}")

rows_to_recover = {
i for i in range(rows) if
cols <= sum(matrix[i]) * 2 < cols * 2
Expand All @@ -31,11 +39,18 @@ def recover(matrix):
for i in range(rows):
matrix[i][col] = 1
print(print_form(matrix))
if show_plot:
sns.heatmap(matrix, vmin=0, vmax=1, linewidth=0.5, cbar=False)
plt.pause(0.2)
if sum(sum(row) for row in matrix) == rows * cols:
print(f"Finished in {_round} rounds")
if show_plot:
plt.show()
return _round
if rows_to_recover == cols_to_recover == set():
print("Recovery failed")
if show_plot:
plt.show()
return None
raise Exception("wtf happened here ^_^")

Expand All @@ -45,24 +60,33 @@ def parse(text):
rows = text.strip().split(separator)
return [[int(x) for x in row] for row in rows]

def mk_evil_matrix(n):
odd_n = n - ((n+1) % 2)
half_odd_n = odd_n // 2
o = mkmatrix(n, n)
for i in range(half_odd_n):
for j in range(half_odd_n + i, odd_n):
o[i][j] = 1
o[half_odd_n][half_odd_n] = 1
for i in range(1, half_odd_n+1):
for j in range(i):
o[half_odd_n + i][j] = 1
def mk_evil_matrix(n, error_alg='default', show_plot=False):
if error_alg == 'default':
odd_n = n - ((n+1) % 2)
half_odd_n = odd_n // 2
o = mkmatrix(n, n)
for i in range(half_odd_n):
for j in range(half_odd_n + i, odd_n):
o[i][j] = 1
o[half_odd_n][half_odd_n] = 1
for i in range(1, half_odd_n+1):
for j in range(i):
o[half_odd_n + i][j] = 1
elif error_alg[:4] == 'rand':
o = mkmatrix(n, n, 1)
n_corrupted = int(error_alg[4:])
# corrupt/withhold the samples with EXACT number
for i in range(n_corrupted):
while True:
r = random.randint(0, n - 1)
c = random.randint(0, n - 1)
if o[r][c] == 1:
o[r][c] = 0
break
return o

def test(n=12):
recover(mk_evil_matrix(n))
def test(n=12, error_alg='default'):
recover(mk_evil_matrix(n, error_alg=error_alg))

if __name__ == '__main__':
if len(sys.argv) == 2:
test(int(sys.argv[-1]))
else:
test()
test(12 if len(sys.argv) < 2 else int(sys.argv[1]), error_alg='default' if len(sys.argv) < 3 else sys.argv[2])
Loading