forked from data61/MP-SPDZ
-
Notifications
You must be signed in to change notification settings - Fork 1
/
tutorial.mpc
139 lines (98 loc) · 3.12 KB
/
tutorial.mpc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# sint: secret integers
# see also https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint

# public Python numbers can be turned into sint values directly
a, b = sint(1), sint(2)

def test(actual, expected):
    """Reveal a secret result and print it next to the expected value.

    actual: a secret value (anything with a ``reveal()`` method).
    expected: the public value the demo expects ``actual`` to equal.

    Nothing is asserted; the comparison is left to whoever reads the output.
    """
    # you can reveal a number in order to print it
    revealed = actual.reveal()
    print_ln('expected %s, got %s', expected, revealed)

# private inputs are read from Player-Data/Input-P<i>-0
# or from standard input if using command-line option -I
# see https://mp-spdz.readthedocs.io/en/latest/io.html for more options
# read one secret integer from each of players 0 and 1 and print it in the clear
for i in 0, 1:
    print_ln('got %s from player %s', sint.get_input_from(i).reveal(), i)

# some arithmetic works as expected
test(a + b, 3)
test(a * b, 2)
test(a - b, -1)

# Division can mean different things in different domains,
# and some of them require a specified bit length,
# so we use int_div() for integer division.
# The second argument (15 here) is that bit length k;
# k-bit division requires (4k+1)-bit computation.
test(b.int_div(a, 15), 2)

# comparisons produce 1 for true and 0 for false
test(a < b, 1)
test(a <= b, 1)
test(a >= b, 0)
test(a > b, 0)
test(a == b, 0)
test(a != b, 1)

# if_else() can be used instead of branching;
# here it selects the larger of a and b
test((a < b).if_else(b, a), 2)

# arrays and loops work as follows
a = Array(100, sint)

@for_range(100)
def f(i):
    # results must be written into the Array to be visible after the loop
    a[i] = sint(i) * sint(i - 1)

test(a[99], 99 * 98)

# if you use loops, use Array to store results.
# A plain variable assigned inside the loop body is local to that body,
# so its value never reaches the code after the loop
# (below, `test(a, 99)` would see the outer Array `a`, not the sint).
# don't do this:
# @for_range(100)
# def f(i):
#     a = sint(i)
# test(a, 99)

# sfix: secret fixed-point numbers
# see also https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix

# precision: 16 bits after the binary point, 31 bits in total
sfix.set_precision(16, 31)
# revealed fixed-point values are printed with 4 decimal digits
print_float_precision(4)

# sfix supports all the basic arithmetic operations, division included
a, b = sfix(2), sfix(-0.1)

# 2 + (-0.1), 2 - (-0.1), 2 * (-0.1), 2 / (-0.1)
test(a + b, 1.9)
test(a - b, 2.1)
test(a * b, -0.2)
test(a / b, -20)

# comparisons again yield 1 for true and 0 for false; here a > b
test(a < b, 0)
test(a <= b, 0)
test(a >= b, 1)
test(a > b, 1)
test(a == b, 0)
test(a != b, 1)

# this time if_else() selects the smaller of the two
test((a < b).if_else(a, b), -0.1)

# now let's do a computation with private inputs
# party 0 supplies three weights and party 1 supplies three values;
# we want to compute the weighted mean of the values
print_ln('Party 0: please input three numbers not adding up to zero')
print_ln('Party 1: please input any three numbers')
data = Matrix(3, 2, sfix)

# use @for_range_opt for balanced optimization,
# but use Python loops where compile-time numbers are needed (e.g., for players)
@for_range_opt(3)
def _(i):
    # column j of each row comes from player j:
    # column 0 is the weight, column 1 the value
    for j in range(2):
        data[i][j] = sfix.get_input_from(j)

# compute weighted average
weight_total = sum(point[0] for point in data)
result = sum(point[0] * point[1] for point in data) / weight_total

# branching is supported also depending on revealed secret data;
# with garbled circuits this triggers an interruption of the garbling.
# Reuse weight_total instead of summing the weights a second time.
@if_e((weight_total != 0).reveal())
def _():
    print_ln('weighted average: %s', result.reveal())
@else_
def _():
    print_ln('your inputs made no sense')

# 2x2 permutation matrix: zeros on the diagonal, ones off it
M = Matrix(2, 2, sfix)
for column in range(2):
    for row in range(2):
        M[row][column] = 0 if row == column else 1

# matrix multiplication: data * M swaps the two columns of data
M = data * M
test(M[0][0], data[0][1].reveal())
test(M[1][1], data[1][0].reveal())