"""Runnable reconstruction of the complete definitions carried by a
whitespace-collapsed, multi-file patch.

Sections, one per file the patch touches:

* ``blind_and_swap/poly_utils.py``        -- ``PrimeField``
* ``erasure_code/2d_recovery/recover.py`` -- ``mkmatrix``, ``print_form``,
  ``mk_evil_matrix``
* ``rollup_compression/dicts.py``         -- ``evaluate_dict_algorithm``,
  ``AlgoHeap``

Only definitions whose every line is present in the patch are reproduced.
``recover()``, ``parse()``, the ``blind_and_swap`` ``test()`` and the
``dicts.py`` driver script appear only as partial hunks and are therefore
not reproduced here.
"""
import random


# ---------------------------------------------------------------------------
# blind_and_swap/poly_utils.py
# ---------------------------------------------------------------------------

# Creates an object that includes convenience operations for numbers
# and polynomials in some prime field
class PrimeField():
    """Arithmetic on integers and polynomials modulo a prime.

    Polynomials are plain lists of coefficients, lowest degree first, so
    ``[1, 0, 1]`` is ``x**2 + 1``.
    """

    def __init__(self, modulus):
        # Base-2 Fermat test: a cheap sanity check, not a primality proof.
        assert pow(2, modulus, modulus) == 2
        self.modulus = modulus

    def add(self, x, y):
        """Return ``x + y`` reduced modulo the field modulus."""
        return (x + y) % self.modulus

    def sub(self, x, y):
        """Return ``x - y`` reduced modulo the field modulus."""
        return (x - y) % self.modulus

    def mul(self, x, y):
        """Return ``x * y`` reduced modulo the field modulus."""
        return (x * y) % self.modulus

    def exp(self, x, p):
        """Return ``x ** p`` reduced modulo the field modulus."""
        return pow(x, p, self.modulus)

    # Modular inverse using the extended Euclidean algorithm
    def inv(self, a):
        """Return the modular inverse of ``a``; by convention ``inv(0) == 0``.

        ``a`` is reduced first so that any multiple of the modulus is
        treated as zero (previously only a literal ``0`` was, and e.g.
        ``inv(modulus)`` returned ``1``).
        """
        low = a % self.modulus
        if low == 0:
            return 0
        lm, hm = 1, 0
        high = self.modulus
        while low > 1:
            r = high // low
            nm, new = hm - lm * r, high - low * r
            lm, low, hm, high = nm, new, lm, low
        return lm % self.modulus

    def multi_inv(self, values):
        """Invert many values with a single call to :meth:`inv`.

        Uses running products (Montgomery's batch-inversion trick).
        Entries congruent to zero map to ``0`` and do not disturb the
        other results.
        """
        m = self.modulus
        vals = [v % m for v in values]
        # partials[i] is the product of the first i non-zero values.
        partials = [1]
        for v in vals:
            partials.append(self.mul(partials[-1], v or 1))
        inv = self.inv(partials[-1])
        outputs = [0] * len(vals)
        # Walk backwards, peeling one factor off the running inverse.
        for i in range(len(vals), 0, -1):
            v = vals[i - 1]
            outputs[i - 1] = self.mul(partials[i - 1], inv) if v else 0
            inv = self.mul(inv, v or 1)
        return outputs

    def div(self, x, y):
        """Return ``x / y`` in the field (``0`` when ``y`` is zero)."""
        return self.mul(x, self.inv(y))

    # Evaluate a polynomial at a point
    def eval_poly_at(self, p, x):
        """Evaluate polynomial ``p`` at ``x`` using Horner's rule."""
        m = self.modulus
        y = 0
        for coeff in reversed(p):
            y = (y * x + coeff) % m
        return y

    # Arithmetic for polynomials
    def add_polys(self, a, b):
        """Return ``a + b``; the shorter operand is zero-padded."""
        m = self.modulus
        size = max(len(a), len(b))
        return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0)) % m
                for i in range(size)]

    def sub_polys(self, a, b):
        """Return ``a - b``; the shorter operand is zero-padded."""
        m = self.modulus
        size = max(len(a), len(b))
        return [((a[i] if i < len(a) else 0) - (b[i] if i < len(b) else 0)) % m
                for i in range(size)]

    def mul_by_const(self, a, c):
        """Return polynomial ``a`` with every coefficient multiplied by ``c``."""
        return [(x * c) % self.modulus for x in a]

    def mul_polys(self, a, b):
        """Return the schoolbook product ``a * b``.

        An empty operand yields an empty product.
        """
        if not a or not b:
            return []
        o = [0] * (len(a) + len(b) - 1)
        for i, aval in enumerate(a):
            for j, bval in enumerate(b):
                o[i + j] += aval * bval
        return [x % self.modulus for x in o]

    def div_polys(self, a, b):
        """Return the quotient of ``a`` divided by ``b`` (remainder dropped).

        Requires ``len(a) >= len(b)``.
        """
        assert len(a) >= len(b)
        m = self.modulus
        a = list(a)  # work on a copy; the caller's list is left untouched
        o = []
        apos = len(a) - 1
        bpos = len(b) - 1
        diff = apos - bpos
        while diff >= 0:
            quot = self.div(a[apos], b[bpos])
            # Quotient coefficients come out highest degree first.
            o.append(quot)
            for i in range(bpos, -1, -1):
                a[diff + i] = (a[diff + i] - b[i] * quot) % m
            apos -= 1
            diff -= 1
        o.reverse()
        return [x % m for x in o]

    def mod_polys(self, a, b):
        """Return the remainder of ``a`` divided by ``b``."""
        return self.sub_polys(a, self.mul_polys(b, self.div_polys(a, b)))[:len(b) - 1]

    # Build a polynomial from a few coefficients
    def sparse(self, coeff_dict):
        """Build a dense coefficient list from ``{degree: coefficient}``.

        An empty mapping yields an empty polynomial.
        """
        if not coeff_dict:
            return []
        o = [0] * (max(coeff_dict.keys()) + 1)
        for k, v in coeff_dict.items():
            o[k] = v % self.modulus
        return o

    # Build a polynomial that returns 0 at all specified xs
    def zpoly(self, xs):
        """Return ``(X - xs[0]) * (X - xs[1]) * ...`` as a coefficient list."""
        m = self.modulus
        root = [1]
        for x in xs:
            # Multiply by X (shift up), then subtract x times the old poly.
            root.insert(0, 0)
            for j in range(len(root) - 1):
                root[j] = (root[j] - root[j + 1] * x) % m
        return [c % m for c in root]

    # Given p+1 y values and x values with no errors, recovers the original
    # degree-p polynomial.
    # Lagrange interpolation works roughly in the following way.
    # 1. Suppose you have a set of points, eg. x = [1, 2, 3], y = [2, 5, 10]
    # 2. For each x, generate a polynomial which equals its corresponding
    #    y coordinate at that point and 0 at all other points provided.
    # 3. Add these polynomials together.

    def lagrange_interp(self, xs, ys):
        """Return the polynomial through the points ``zip(xs, ys)``.

        ``xs`` and ``ys`` must have the same length.
        """
        # Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn)
        root = self.zpoly(xs)
        assert len(root) == len(ys) + 1
        # Generate per-value numerator polynomials, eg. for x=x2,
        # (x - x1) * (x - x3) * ... * (x - xn), by dividing the master
        # polynomial back by each x coordinate
        nums = [self.div_polys(root, [-x, 1]) for x in xs]
        # Generate denominators by evaluating numerator polys at each x
        denoms = [self.eval_poly_at(num, x) for num, x in zip(nums, xs)]
        invdenoms = self.multi_inv(denoms)
        # Generate output polynomial, which is the sum of the per-value numerator
        # polynomials rescaled to have the right y values
        b = [0] * len(ys)
        for num, y, invdenom in zip(nums, ys, invdenoms):
            if not y:
                continue  # a zero y contributes nothing
            yslice = self.mul(y, invdenom)
            for j, coeff in enumerate(num):
                b[j] += coeff * yslice
        return [x % self.modulus for x in b]

    # Optimized poly evaluation for 4-coefficient (degree <= 3) polynomials
    def eval_quartic(self, p, x):
        """Evaluate ``p[0] + p[1]*x + p[2]*x**2 + p[3]*x**3``.

        Despite the name only four coefficients are read.
        """
        xsq = x * x % self.modulus
        xcb = xsq * x
        return (p[0] + p[1] * x + p[2] * xsq + p[3] * xcb) % self.modulus

    def _basis_4(self, xs):
        """Return, for each of four points, the monic cubic vanishing at the other three."""
        m = self.modulus
        x0, x1, x2, x3 = xs[0], xs[1], xs[2], xs[3]
        x01, x02, x03 = x0 * x1, x0 * x2, x0 * x3
        x12, x13, x23 = x1 * x2, x1 * x3, x2 * x3
        return [
            [-x12 * x3 % m, (x12 + x13 + x23) % m, (-x1 - x2 - x3) % m, 1],
            [-x02 * x3 % m, (x02 + x03 + x23) % m, (-x0 - x2 - x3) % m, 1],
            [-x01 * x3 % m, (x01 + x03 + x13) % m, (-x0 - x1 - x3) % m, 1],
            [-x01 * x2 % m, (x01 + x02 + x12) % m, (-x0 - x1 - x2) % m, 1],
        ]

    # Optimized version of lagrange_interp restricted to 4 points
    # (result has 4 coefficients, i.e. degree <= 3)
    def lagrange_interp_4(self, xs, ys):
        """Interpolate through exactly four points with one field inversion."""
        m = self.modulus
        eqs = self._basis_4(xs)
        e0, e1, e2, e3 = (self.eval_poly_at(eq, x) for eq, x in zip(eqs, xs))
        e01 = e0 * e1
        e23 = e2 * e3
        # One inversion of the full product; each 1/e_i is recovered by
        # multiplying back the other three denominators.
        invall = self.inv(e01 * e23)
        scale = [
            ys[0] * invall * e1 * e23 % m,
            ys[1] * invall * e0 * e23 % m,
            ys[2] * invall * e01 * e3 % m,
            ys[3] * invall * e01 * e2 % m,
        ]
        return [sum(eqs[t][c] * scale[t] for t in range(4)) % m for c in range(4)]

    # Optimized version of lagrange_interp restricted to 2 points
    # (result has 2 coefficients, i.e. degree <= 1)
    def lagrange_interp_2(self, xs, ys):
        """Interpolate the line through exactly two points."""
        m = self.modulus
        eq0 = [-xs[1] % m, 1]
        eq1 = [-xs[0] % m, 1]
        e0 = self.eval_poly_at(eq0, xs[0])
        e1 = self.eval_poly_at(eq1, xs[1])
        invall = self.inv(e0 * e1)
        inv_y0 = ys[0] * invall * e1
        inv_y1 = ys[1] * invall * e0
        return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)]

    # Batched version of lagrange_interp_4: many 4-point interpolations
    # sharing a single field inversion
    def multi_interp_4(self, xsets, ysets):
        """Run :meth:`lagrange_interp_4` on each ``(xs, ys)`` pair.

        All denominators are inverted together through :meth:`multi_inv`.
        """
        m = self.modulus
        data = []
        invtargets = []
        for xs, ys in zip(xsets, ysets):
            eqs = self._basis_4(xs)
            data.append((ys, eqs))
            invtargets.extend(self.eval_quartic(eq, x) for eq, x in zip(eqs, xs))
        invalls = self.multi_inv(invtargets)
        o = []
        for k, (ys, eqs) in enumerate(data):
            scale = [ys[t] * invalls[4 * k + t] % m for t in range(4)]
            o.append([sum(eqs[t][c] * scale[t] for t in range(4)) % m
                      for c in range(4)])
        return o


# ---------------------------------------------------------------------------
# erasure_code/2d_recovery/recover.py
# ---------------------------------------------------------------------------

def mkmatrix(rows, cols, default_value=0):
    """Return a ``rows`` x ``cols`` list-of-lists filled with ``default_value``."""
    return [[default_value for _ in range(cols)] for _ in range(rows)]


def print_form(mat):
    """Render a matrix as text: one line per row, cells concatenated."""
    return '\n'.join(''.join(str(val) for val in row) for row in mat)


def mk_evil_matrix(n, error_alg='default', show_plot=False):
    """Build an ``n`` x ``n`` 0/1 matrix to feed to the recovery routine.

    ``error_alg`` selects the pattern:

    * ``'default'`` -- the fixed hand-crafted pattern.
    * ``'rand<k>'`` -- all ones except exactly ``k`` distinct, uniformly
      chosen cells set to 0 (e.g. ``'rand20'``).

    ``show_plot`` is accepted for interface compatibility and is not used.

    Raises ``ValueError`` for an unknown ``error_alg`` or a ``k`` outside
    ``0..n*n``.
    """
    if error_alg == 'default':
        odd_n = n - ((n + 1) % 2)
        half_odd_n = odd_n // 2
        o = mkmatrix(n, n)
        for i in range(half_odd_n):
            for j in range(half_odd_n + i, odd_n):
                o[i][j] = 1
        o[half_odd_n][half_odd_n] = 1
        for i in range(1, half_odd_n + 1):
            for j in range(i):
                o[half_odd_n + i][j] = 1
    elif error_alg[:4] == 'rand':
        n_corrupted = int(error_alg[4:])
        if not 0 <= n_corrupted <= n * n:
            # A rejection-sampling loop would never terminate here.
            raise ValueError(
                "cannot corrupt {} cells of a {}x{} matrix".format(n_corrupted, n, n))
        o = mkmatrix(n, n, 1)
        # corrupt/withhold the samples with EXACT number: sampling without
        # replacement picks n_corrupted distinct cells in one step
        for cell in random.sample(range(n * n), n_corrupted):
            o[cell // n][cell % n] = 0
    else:
        raise ValueError("unknown error_alg: {!r}".format(error_alg))
    return o


# ---------------------------------------------------------------------------
# rollup_compression/dicts.py
#
# SAMPLE_COUNT, VOCAB_SIZE, DICT_SIZE and zipfs_law_sample are module-level
# names of dicts.py that are defined outside the lines reproduced here.
# ---------------------------------------------------------------------------

# Evaluates a given algorithm for updating the dictionary
def evaluate_dict_algorithm(Algo, trials=1, trace_interval=0):
    """Count dictionary hits of ``Algo`` over ``trials`` independent runs.

    Each trial builds a fresh ``Algo()`` and feeds it ``SAMPLE_COUNT``
    samples; a hit is counted when ``in_dict`` is true before the sample
    is passed to ``process``.

    Returns the total hit count when ``trace_interval`` is 0. Otherwise
    returns ``(total_hits, hits_trace)`` where ``hits_trace[k]`` is the
    running hit count at sample ``k * trace_interval`` summed over trials.
    """
    hits_trace = []
    total_hits = 0
    for trial in range(trials):
        algo = Algo()
        hits = 0
        for sample in range(SAMPLE_COUNT):
            v = zipfs_law_sample(VOCAB_SIZE)
            if algo.in_dict(v):
                hits += 1
            if trace_interval != 0 and sample % trace_interval == 0:
                if trial == 0:
                    hits_trace.append(hits)
                else:
                    hits_trace[sample // trace_interval] += hits
            algo.process(v)
        total_hits += hits
    if trace_interval == 0:
        return total_hits
    return total_hits, hits_trace


# Stores the frequency table. The dictionary is the top N addresses in
# the frequency table, kept in a min-heap ordered by frequency.
class AlgoHeap():
    """Keep the most frequent values in a bounded min-heap.

    Heap layout: slot 0 holds the minimum and has slot 1 as its only real
    child; from slot 1 on, the children of ``p`` are ``2p`` and ``2p + 1``
    and the parent of ``p`` is ``p // 2``.
    """

    def __init__(self, sift_limit=64, dict_size=None):
        self.freqs = {}
        self.dict = {}       # index to value mapping
        self.positions = {}  # value to index mapping
        # Stored for callers; none of the methods below consult it.
        self.sift_limit = sift_limit
        # None means "use the module-level DICT_SIZE".
        self.dict_size = dict_size

    def in_dict(self, v):
        """Return whether ``v`` is currently held in the heap."""
        # positions has exactly the heap's values as keys: O(1) lookup
        # instead of scanning self.dict.values().
        return v in self.positions

    def siftDown(self, pos):
        """Move the entry at ``pos`` one level down if a child is smaller.

        Returns the new position after a swap, or ``None`` when the entry
        is already in place.
        """
        v = self.dict[pos]
        left = pos * 2
        if left >= len(self.dict):
            return None
        left_v = self.dict[left]
        right = pos * 2 + 1
        if right >= len(self.dict) or self.freqs[left_v] < self.freqs[self.dict[right]]:
            min_child = left
            min_v = left_v
        else:
            min_child = right
            min_v = self.dict[right]

        if self.freqs[min_v] < self.freqs[v]:
            # swap
            self.positions[min_v] = pos
            self.dict[pos] = min_v
            self.positions[v] = min_child
            self.dict[min_child] = v
            return min_child
        return None

    def siftUp(self, pos):
        """Move the entry at ``pos`` one level up if its parent is larger.

        Returns the new position after a swap, or ``None`` when the entry
        is already in place.
        """
        if pos == 0:
            return None
        v = self.dict[pos]
        parent = pos // 2
        parent_v = self.dict[parent]

        if self.freqs[parent_v] > self.freqs[v]:
            # swap
            self.positions[parent_v] = pos
            self.dict[pos] = parent_v
            self.positions[v] = parent
            self.dict[parent] = v
            return parent
        return None

    def process(self, v):
        """Record one occurrence of ``v`` and update the heap."""
        capacity = DICT_SIZE if self.dict_size is None else self.dict_size
        self.freqs[v] = self.freqs.get(v, 0) + 1
        if v in self.positions:
            # Frequency grew, so in a min-heap the entry can only sink.
            p = self.positions[v]
            while p is not None:
                p = self.siftDown(p)
        elif len(self.dict) < capacity:
            p = len(self.dict)
            self.dict[p] = v
            self.positions[v] = p
            while p is not None:
                p = self.siftUp(p)
        elif self.dict and self.freqs[v] > self.freqs[self.dict[0]]:
            # Evict the current minimum and let the newcomer sink.
            last_value = self.dict[0]
            del self.positions[last_value]
            self.positions[v] = 0
            self.dict[0] = v
            p = 0
            while p is not None:
                p = self.siftDown(p)