@@ -64,13 +64,10 @@ class MulCtCtOp : public OpKernel {
         op_ctx, bcast.IsValid(),
         InvalidArgument("Invalid broadcast between ", a.shape().DebugString(),
                         " and ", b.shape().DebugString()));
-    auto flat_a = MyBFlat<Variant>(op_ctx, a, bcast.x_reshape(), bcast.x_bcast());
-    auto flat_b = MyBFlat<Variant>(op_ctx, b, bcast.y_reshape(), bcast.y_bcast());
-
-    // Check the inputs have the same shape.
-    OP_REQUIRES(
-        op_ctx, flat_a.size() == flat_b.size(),
-        InvalidArgument("Broadcasted inputs must have the same shape."));
+    auto flat_a = a.flat<Variant>();
+    auto flat_b = b.flat<Variant>();
+    IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape());
+    IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape());
 
     // Allocate the output tensor which is the same shape as each of the inputs.
     Tensor* output;
@@ -81,14 +78,14 @@ class MulCtCtOp : public OpKernel {
     // Multiply each pair of ciphertexts and store the result in the output.
     for (int i = 0; i < flat_output.dimension(0); ++i) {
       SymmetricCtVariant<T> const* ct_a_var =
-          std::move(flat_a(i).get<SymmetricCtVariant<T>>());
+          std::move(flat_a(a_bcaster(i)).get<SymmetricCtVariant<T>>());
       OP_REQUIRES(op_ctx, ct_a_var != nullptr,
                   InvalidArgument("SymmetricCtVariant at flat index:", i,
                                   " for input a did not unwrap successfully."));
       SymmetricCt const& ct_a = ct_a_var->ct;
 
       SymmetricCtVariant<T> const* ct_b_var =
-          std::move(flat_b(i).get<SymmetricCtVariant<T>>());
+          std::move(flat_b(b_bcaster(i)).get<SymmetricCtVariant<T>>());
       OP_REQUIRES(op_ctx, ct_b_var != nullptr,
                   InvalidArgument("SymmetricCtVariant at flat index:", i,
                                   " for input b did not unwrap successfully."));
@@ -124,46 +121,63 @@ class MulCtPtOp : public OpKernel {
         op_ctx, bcast.IsValid(),
         InvalidArgument("Invalid broadcast between ", a.shape().DebugString(),
                         " and ", b.shape().DebugString()));
-    auto flat_a = MyBFlat<Variant>(op_ctx, a, bcast.x_reshape(), bcast.x_bcast());
-    auto flat_b = MyBFlat<Variant>(op_ctx, b, bcast.y_reshape(), bcast.y_bcast());
-
-    // Check the inputs have the same shape.
-    OP_REQUIRES(
-        op_ctx, flat_a.size() == flat_b.size(),
-        InvalidArgument("Broadcasted inputs must have the same shape."));
+    auto flat_a = a.flat<Variant>();
+    auto flat_b = b.flat<Variant>();
+    IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape());
+    IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape());
 
     // Allocate the output tensor which is the same shape as each of the inputs.
     Tensor* output;
     TensorShape output_shape = BCast::ToShape(bcast.output_shape());
     OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output));
     auto flat_output = output->flat<Variant>();
 
-    for (int i = 0; i < flat_output.dimension(0); ++i) {
-      SymmetricCtVariant<T> const* ct_a_var =
-          std::move(flat_a(i).get<SymmetricCtVariant<T>>());
-      OP_REQUIRES(op_ctx, ct_a_var != nullptr,
-                  InvalidArgument("SymmetricCtVariant at flat index:", i,
-                                  " for input a did not unwrap successfully."));
-      SymmetricCt const& ct_a = ct_a_var->ct;
+    // Recover num_slots from first ciphertext.
+    SymmetricCtVariant<T> const* ct_var =
+        std::move(flat_a(0).get<SymmetricCtVariant<T>>());
+    OP_REQUIRES(
+        op_ctx, ct_var != nullptr,
+        InvalidArgument("SymmetricCtVariant a did not unwrap successfully."));
+    SymmetricCt const& ct = ct_var->ct;
+    int num_slots = 1 << ct.LogN();
+    int num_components = ct.NumModuli();
 
-      PolynomialVariant<T> const* pv_b_var =
-          std::move(flat_b(i).get<PolynomialVariant<T>>());
-      OP_REQUIRES(op_ctx, pv_b_var != nullptr,
-                  InvalidArgument("PolynomialVariant at flat index:", i,
-                                  " for input b did not unwrap successfully."));
-      RnsPolynomial const& pt_b = pv_b_var->poly;
+    auto mul_in_range = [&](int start, int end) {
+      for (int i = start; i < end; ++i) {
+        SymmetricCtVariant<T> const* ct_a_var =
+            std::move(flat_a(a_bcaster(i)).get<SymmetricCtVariant<T>>());
+        OP_REQUIRES(
+            op_ctx, ct_a_var != nullptr,
+            InvalidArgument("SymmetricCtVariant at flat index:", i,
+                            " for input a did not unwrap successfully."));
+        SymmetricCt const& ct_a = ct_a_var->ct;
 
-      OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx,
-                        ct_a * pt_b);  // shell absorb operation
+        PolynomialVariant<T> const* pv_b_var =
+            std::move(flat_b(b_bcaster(i)).get<PolynomialVariant<T>>());
+        OP_REQUIRES(
+            op_ctx, pv_b_var != nullptr,
+            InvalidArgument("PolynomialVariant at flat index:", i,
+                            " for input b did not unwrap successfully."));
+        RnsPolynomial const& pt_b = pv_b_var->poly;
 
-      SymmetricCtVariant ct_c_var(std::move(ct_c));
-      flat_output(i) = std::move(ct_c_var);
-    }
+        OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx,
+                          ct_a * pt_b);  // shell absorb operation
+
+        SymmetricCtVariant ct_c_var(std::move(ct_c));
+        flat_output(i) = std::move(ct_c_var);
+      }
+    };
+
+    auto thread_pool =
+        op_ctx->device()->tensorflow_cpu_worker_threads()->workers;
+    int const cost_per_mul = 30 * num_slots * num_components;
+    thread_pool->ParallelFor(flat_output.dimension(0), cost_per_mul,
+                             mul_in_range);
   }
 };
 
-// This Op can multiply either a shell ciphertext or a plaintext polynomial by a
-// plaintext scalar, depending on the class template.
+// This Op can multiply either a shell ciphertext or a plaintext polynomial by
+// a plaintext scalar, depending on the class template.
 template <typename T, typename PtT, typename CtOrPolyVariant>
 class MulShellTfScalarOp : public OpKernel {
  private:
@@ -194,12 +208,8 @@ class MulShellTfScalarOp : public OpKernel {
         InvalidArgument("Invalid broadcast between ", a.shape().DebugString(),
                         " and ", b.shape().DebugString()));
     auto flat_a = a.flat<Variant>();  // a is not broadcasted, just b.
-    auto flat_b = MyBFlat<PtT>(op_ctx, b, bcast.y_reshape(), bcast.y_bcast());
-
-    // Check the inputs have the same shape.
-    OP_REQUIRES(
-        op_ctx, flat_a.size() == flat_b.size(),
-        InvalidArgument("Broadcasted inputs must have the same shape."));
+    auto flat_b = b.flat<PtT>();
+    IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape());
 
     // Allocate the output tensor which is the same shape as the first input.
     Tensor* output;
@@ -211,7 +221,7 @@ class MulShellTfScalarOp : public OpKernel {
       // First encode the scalar b
       // TDOO(jchoncholas): encode all scalars at once beforehand.
       T wrapped_b;
-      EncodeScalar(op_ctx, flat_b(i), encoder, &wrapped_b);
+      EncodeScalar(op_ctx, flat_b(b_bcaster(i)), encoder, &wrapped_b);
 
       CtOrPolyVariant const* ct_or_pt_var =
           std::move(flat_a(i).get<CtOrPolyVariant>());
@@ -241,7 +251,7 @@ class MulShellTfScalarOp : public OpKernel {
     }
   }
 
-  private:
+ private:
   void EncodeScalar(OpKernelContext* op_ctx, PtT const& val, Encoder const* encoder, T* wrapped_val) {
     if constexpr (std::is_signed<PtT>::value) {
       // SHELL is built on the assumption that the plaintext type (in this
@@ -293,30 +303,28 @@ class MulPtPtOp : public OpKernel {
         op_ctx, bcast.IsValid(),
         InvalidArgument("Invalid broadcast between ", a.shape().DebugString(),
                         " and ", b.shape().DebugString()));
-    auto flat_a = MyBFlat<Variant>(op_ctx, a, bcast.x_reshape(), bcast.x_bcast());
-    auto flat_b = MyBFlat<Variant>(op_ctx, b, bcast.y_reshape(), bcast.y_bcast());
-
-    // Check the inputs have the same shape.
-    OP_REQUIRES(
-        op_ctx, flat_a.size() == flat_b.size(),
-        InvalidArgument("Broadcasted inputs must have the same shape."));
+    auto flat_a = a.flat<Variant>();
+    auto flat_b = b.flat<Variant>();
+    IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape());
+    IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape());
 
-    // Allocate the output tensor which is the same shape as each of the inputs.
+    // Allocate the output tensor which is the same shape as each of the
+    // inputs.
     Tensor* output;
     TensorShape output_shape = BCast::ToShape(bcast.output_shape());
     OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output));
     auto flat_output = output->flat<Variant>();
 
     for (int i = 0; i < flat_output.dimension(0); ++i) {
       PolynomialVariant<T> const* pv_a_var =
-          std::move(flat_a(i).get<PolynomialVariant<T>>());
+          std::move(flat_a(a_bcaster(i)).get<PolynomialVariant<T>>());
       OP_REQUIRES(op_ctx, pv_a_var != nullptr,
                   InvalidArgument("PolynomialVariant at flat index:", i,
                                   " for input a did not unwrap successfully."));
       RnsPolynomial const& pt_a = pv_a_var->poly;
 
       PolynomialVariant<T> const* pv_b_var =
-          std::move(flat_b(i).get<PolynomialVariant<T>>());
+          std::move(flat_b(b_bcaster(i)).get<PolynomialVariant<T>>());
       OP_REQUIRES(op_ctx, pv_b_var != nullptr,
                   InvalidArgument("PolynomialVariant at flat index:", i,
                                   " for input b did not unwrap successfully."));