Skip to content

Commit f6f9e50

Browse files
Rerun examples with new op broadcasting.
1 parent c7e13a2 commit f6f9e50

File tree

6 files changed

+1370
-183
lines changed

6 files changed

+1370
-183
lines changed

examples/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tflogs

examples/benchmark.ipynb

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Introduction to tf-shell\n",
8+
"\n",
9+
"To get started, `pip install tf-shell`. tf-shell has a few modules, the one used\n",
10+
"in this notebook is `tf_shell`."
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": 1,
16+
"metadata": {},
17+
"outputs": [
18+
{
19+
"name": "stderr",
20+
"output_type": "stream",
21+
"text": [
22+
"2024-06-10 21:30:04.256734: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
23+
"2024-06-10 21:30:04.257664: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
24+
"2024-06-10 21:30:04.291533: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
25+
"2024-06-10 21:30:04.428601: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
26+
"To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
27+
"2024-06-10 21:30:05.195988: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
28+
]
29+
}
30+
],
31+
"source": [
32+
"import tf_shell\n",
33+
"import tensorflow as tf\n",
34+
"import timeit\n",
35+
"\n",
36+
"context = tf_shell.create_context64(\n",
37+
" log_n=10,\n",
38+
" main_moduli=[8556589057, 8388812801],\n",
39+
" plaintext_modulus=40961,\n",
40+
" scaling_factor=3,\n",
41+
" mul_depth_supported=3,\n",
42+
" seed=\"test_seed\",\n",
43+
")\n",
44+
"\n",
45+
"secret_key = tf_shell.create_key64(context)\n",
46+
"rotation_key = tf_shell.create_rotation_key64(context, secret_key)\n",
47+
"\n",
48+
"a = tf.random.uniform([context.num_slots, 55555], dtype=tf.float32, maxval=10)\n",
49+
"b = tf.random.uniform([55555, 333], dtype=tf.float32, maxval=10)\n",
50+
"c = tf.random.uniform([2, context.num_slots], dtype=tf.float32, maxval=10)\n",
51+
"d = tf.random.uniform([context.num_slots, 4444], dtype=tf.float32, maxval=10)\n",
52+
"\n",
53+
"enc_a = tf_shell.to_encrypted(a, secret_key, context)"
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": 2,
59+
"metadata": {},
60+
"outputs": [
61+
{
62+
"name": "stdout",
63+
"output_type": "stream",
64+
"text": [
65+
"0.4906675929996709\n"
66+
]
67+
}
68+
],
69+
"source": [
70+
"def to_pt():\n",
71+
" return tf_shell.to_shell_plaintext(a, context)\n",
72+
"\n",
73+
"time = min(timeit.Timer(to_pt).repeat(repeat=3, number=1))\n",
74+
"print(time)"
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": 3,
80+
"metadata": {},
81+
"outputs": [
82+
{
83+
"name": "stdout",
84+
"output_type": "stream",
85+
"text": [
86+
"5.263423050000711\n"
87+
]
88+
}
89+
],
90+
"source": [
91+
"def enc():\n",
92+
" return tf_shell.to_encrypted(d, secret_key, context)\n",
93+
"\n",
94+
"time = min(timeit.Timer(enc).repeat(repeat=3, number=1))\n",
95+
"print(time)"
96+
]
97+
},
98+
{
99+
"cell_type": "code",
100+
"execution_count": 4,
101+
"metadata": {},
102+
"outputs": [
103+
{
104+
"name": "stdout",
105+
"output_type": "stream",
106+
"text": [
107+
"0.5277276859997073\n"
108+
]
109+
}
110+
],
111+
"source": [
112+
"def dec():\n",
113+
" return tf_shell.to_tensorflow(enc_a, secret_key)\n",
114+
"\n",
115+
"time = min(timeit.Timer(dec).repeat(repeat=3, number=1))\n",
116+
"print(time)"
117+
]
118+
},
119+
{
120+
"cell_type": "code",
121+
"execution_count": 5,
122+
"metadata": {},
123+
"outputs": [
124+
{
125+
"name": "stdout",
126+
"output_type": "stream",
127+
"text": [
128+
"0.4192462440005329\n"
129+
]
130+
}
131+
],
132+
"source": [
133+
"def ct_ct_add():\n",
134+
" return enc_a + enc_a\n",
135+
"\n",
136+
"time = min(timeit.Timer(ct_ct_add).repeat(repeat=3, number=1))\n",
137+
"print(time)"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": 6,
143+
"metadata": {},
144+
"outputs": [
145+
{
146+
"name": "stdout",
147+
"output_type": "stream",
148+
"text": [
149+
"0.4219015720009338\n"
150+
]
151+
}
152+
],
153+
"source": [
154+
"def ct_ct_sub():\n",
155+
" return enc_a - enc_a\n",
156+
"\n",
157+
"time = min(timeit.Timer(ct_ct_sub).repeat(repeat=3, number=1))\n",
158+
"print(time)"
159+
]
160+
},
161+
{
162+
"cell_type": "code",
163+
"execution_count": 7,
164+
"metadata": {},
165+
"outputs": [
166+
{
167+
"name": "stdout",
168+
"output_type": "stream",
169+
"text": [
170+
"0.8668678089998139\n"
171+
]
172+
}
173+
],
174+
"source": [
175+
"def ct_ct_mul():\n",
176+
" return enc_a * enc_a\n",
177+
"\n",
178+
"time = min(timeit.Timer(ct_ct_mul).repeat(repeat=3, number=1))\n",
179+
"print(time)"
180+
]
181+
},
182+
{
183+
"cell_type": "code",
184+
"execution_count": 8,
185+
"metadata": {},
186+
"outputs": [
187+
{
188+
"name": "stdout",
189+
"output_type": "stream",
190+
"text": [
191+
"0.7579904609992809\n"
192+
]
193+
}
194+
],
195+
"source": [
196+
"def ct_pt_add():\n",
197+
" return enc_a + a\n",
198+
"\n",
199+
"time = min(timeit.Timer(ct_pt_add).repeat(repeat=3, number=1))\n",
200+
"print(time)"
201+
]
202+
},
203+
{
204+
"cell_type": "code",
205+
"execution_count": 9,
206+
"metadata": {},
207+
"outputs": [
208+
{
209+
"name": "stdout",
210+
"output_type": "stream",
211+
"text": [
212+
"0.6268679120003071\n"
213+
]
214+
}
215+
],
216+
"source": [
217+
"def ct_pt_mul():\n",
218+
" return enc_a * a\n",
219+
"\n",
220+
"time = min(timeit.Timer(ct_pt_mul).repeat(repeat=3, number=1))\n",
221+
"print(time)"
222+
]
223+
},
224+
{
225+
"cell_type": "code",
226+
"execution_count": 10,
227+
"metadata": {},
228+
"outputs": [
229+
{
230+
"name": "stdout",
231+
"output_type": "stream",
232+
"text": [
233+
"25.57404864599812\n"
234+
]
235+
}
236+
],
237+
"source": [
238+
"def ct_pt_matmul():\n",
239+
" return tf_shell.matmul(enc_a, b)\n",
240+
"\n",
241+
"time = min(timeit.Timer(ct_pt_matmul).repeat(repeat=3, number=1))\n",
242+
"print(time)"
243+
]
244+
},
245+
{
246+
"cell_type": "code",
247+
"execution_count": 11,
248+
"metadata": {},
249+
"outputs": [
250+
{
251+
"name": "stdout",
252+
"output_type": "stream",
253+
"text": [
254+
"361.1888753159983\n"
255+
]
256+
}
257+
],
258+
"source": [
259+
"def pt_ct_matmul():\n",
260+
" return tf_shell.matmul(c, enc_a, rotation_key)\n",
261+
"\n",
262+
"time = min(timeit.Timer(pt_ct_matmul).repeat(repeat=3, number=1))\n",
263+
"print(time)"
264+
]
265+
},
266+
{
267+
"cell_type": "code",
268+
"execution_count": 12,
269+
"metadata": {},
270+
"outputs": [
271+
{
272+
"name": "stdout",
273+
"output_type": "stream",
274+
"text": [
275+
"4.650902364999638\n"
276+
]
277+
}
278+
],
279+
"source": [
280+
"def ct_roll():\n",
281+
" return tf_shell.roll(enc_a, 2, rotation_key)\n",
282+
"\n",
283+
"time = min(timeit.Timer(ct_roll).repeat(repeat=3, number=1))\n",
284+
"print(time)"
285+
]
286+
}
287+
],
288+
"metadata": {
289+
"kernelspec": {
290+
"display_name": ".venv",
291+
"language": "python",
292+
"name": "python3"
293+
},
294+
"language_info": {
295+
"codemirror_mode": {
296+
"name": "ipython",
297+
"version": 3
298+
},
299+
"file_extension": ".py",
300+
"mimetype": "text/x-python",
301+
"name": "python",
302+
"nbconvert_exporter": "python",
303+
"pygments_lexer": "ipython3",
304+
"version": "3.10.12"
305+
}
306+
},
307+
"nbformat": 4,
308+
"nbformat_minor": 2
309+
}

examples/intro.ipynb

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,22 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 2,
15+
"execution_count": 1,
1616
"metadata": {},
17-
"outputs": [],
17+
"outputs": [
18+
{
19+
"name": "stderr",
20+
"output_type": "stream",
21+
"text": [
22+
"2024-06-10 21:54:19.630781: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
23+
"2024-06-10 21:54:19.631217: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
24+
"2024-06-10 21:54:19.633550: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
25+
"2024-06-10 21:54:19.663933: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
26+
"To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
27+
"2024-06-10 21:54:20.301503: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
28+
]
29+
}
30+
],
1831
"source": [
1932
"import tf_shell"
2033
]
@@ -38,7 +51,7 @@
3851
},
3952
{
4053
"cell_type": "code",
41-
"execution_count": 3,
54+
"execution_count": 2,
4255
"metadata": {},
4356
"outputs": [],
4457
"source": [
@@ -69,14 +82,14 @@
6982
},
7083
{
7184
"cell_type": "code",
72-
"execution_count": 4,
85+
"execution_count": 3,
7386
"metadata": {},
7487
"outputs": [
7588
{
7689
"name": "stdout",
7790
"output_type": "stream",
7891
"text": [
79-
"The first 3 elements of the data are [0.62601924 4.461747 5.8008575 ]\n"
92+
"The first 3 elements of the data are [0.05918741 3.8001454 5.9336624 ]\n"
8093
]
8194
}
8295
],
@@ -99,7 +112,7 @@
99112
},
100113
{
101114
"cell_type": "code",
102-
"execution_count": 5,
115+
"execution_count": 4,
103116
"metadata": {},
104117
"outputs": [],
105118
"source": [
@@ -132,16 +145,16 @@
132145
},
133146
{
134147
"cell_type": "code",
135-
"execution_count": 6,
148+
"execution_count": 5,
136149
"metadata": {},
137150
"outputs": [
138151
{
139152
"name": "stdout",
140153
"output_type": "stream",
141154
"text": [
142-
"enc: [0.6666667 4.3333335 5.6666665]\n",
143-
"enc + enc: [ 1.3333334 8.666667 11.333333 ]\n",
144-
"enc * enc: [ 0.44444445 18.777779 32.11111 ]\n"
155+
"enc: [0. 3.6666667 6. ]\n",
156+
"enc + enc: [ 0. 7.3333335 12. ]\n",
157+
"enc * enc: [ 0. 13.444445 36. ]\n"
145158
]
146159
}
147160
],

0 commit comments

Comments
 (0)