
Commit c7e13a2

Op level broadcasting does not require copy.
1 parent 829d8cf commit c7e13a2
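
The substance of this commit: instead of materializing broadcast copies of the operands with MyBFlat, each kernel now reads its inputs in place via flat<Variant>() and maps every flat output index back to the matching input index with an IndexConverterFunctor built from the broadcast output shape and that input's shape. The functor's definition is not part of this diff; below is a minimal, hypothetical sketch of such an index mapping (the name FlatBroadcastIndexMapper and the plain std::vector shape types are assumptions for illustration, not the repository's API), included only to show why no copy of the broadcasted input is needed.

// Hypothetical sketch, NOT the repository's IndexConverterFunctor: maps a
// flat index into the broadcasted output shape back to a flat index into a
// (possibly smaller) input shape, so the input is never expanded in memory.
#include <cstdint>
#include <vector>

class FlatBroadcastIndexMapper {
 public:
  // `output_shape` is the broadcast result shape; `input_shape` is one
  // operand's shape. Size-1 (or missing leading) input dimensions are
  // broadcast, i.e. they contribute nothing to the input index.
  FlatBroadcastIndexMapper(std::vector<int64_t> const& output_shape,
                           std::vector<int64_t> const& input_shape) {
    int64_t rank = output_shape.size();
    int64_t in_rank = input_shape.size();
    dims_.resize(rank);
    strides_.resize(rank);
    int64_t stride = 1;
    for (int64_t d = rank - 1; d >= 0; --d) {
      int64_t in_d = d - (rank - in_rank);  // align trailing dimensions
      int64_t in_dim = in_d >= 0 ? input_shape[in_d] : 1;
      dims_[d] = output_shape[d];
      strides_[d] = (in_dim == 1) ? 0 : stride;  // broadcast dims -> stride 0
      stride *= in_dim;
    }
  }

  // Convert a flat output index to the corresponding flat input index.
  int64_t operator()(int64_t out_index) const {
    int64_t in_index = 0;
    for (int64_t d = dims_.size() - 1; d >= 0; --d) {
      in_index += (out_index % dims_[d]) * strides_[d];
      out_index /= dims_[d];
    }
    return in_index;
  }

 private:
  std::vector<int64_t> dims_;     // output dimensions
  std::vector<int64_t> strides_;  // input strides, 0 on broadcast dimensions
};

Because broadcast dimensions get a stride of zero, every output position along them resolves to the same input element, so a read like flat_a(a_bcaster(i)) in the loops below touches the small input directly instead of a pre-expanded copy.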

4 files changed, +145 -129 lines changed


tf_shell/cc/kernels/add_kernels.cc

Lines changed: 18 additions & 27 deletions
@@ -106,13 +106,10 @@ class AddCtCtOp : public OpKernel {
         op_ctx, bcast.IsValid(),
         InvalidArgument("Invalid broadcast between ", a.shape().DebugString(),
                         " and ", b.shape().DebugString()));
-    auto flat_a = MyBFlat<Variant>(op_ctx, a, bcast.x_reshape(), bcast.x_bcast());
-    auto flat_b = MyBFlat<Variant>(op_ctx, b, bcast.y_reshape(), bcast.y_bcast());
-
-    // Check the inputs have the same shape.
-    OP_REQUIRES(
-        op_ctx, flat_a.size() == flat_b.size(),
-        InvalidArgument("Broadcasted inputs must have the same shape."));
+    auto flat_a = a.flat<Variant>();
+    auto flat_b = b.flat<Variant>();
+    IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape());
+    IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape());
 
     // Allocate the output tensor which is the same size as one of the inputs.
     Tensor* output;
@@ -122,14 +119,14 @@ class AddCtCtOp : public OpKernel {
 
     for (int i = 0; i < flat_output.dimension(0); ++i) {
       SymmetricCtVariant<T> const* ct_a_var =
-          std::move(flat_a(i).get<SymmetricCtVariant<T>>());
+          std::move(flat_a(a_bcaster(i)).get<SymmetricCtVariant<T>>());
       OP_REQUIRES(op_ctx, ct_a_var != nullptr,
                   InvalidArgument("SymmetricCtVariant at flat index: ", i,
                                   " for input a did not unwrap successfully."));
       SymmetricCt const& ct_a = ct_a_var->ct;
 
       SymmetricCtVariant<T> const* ct_b_var =
-          std::move(flat_b(i).get<SymmetricCtVariant<T>>());
+          std::move(flat_b(b_bcaster(i)).get<SymmetricCtVariant<T>>());
       OP_REQUIRES(op_ctx, ct_b_var != nullptr,
                   InvalidArgument("SymmetricCtVariant at flat index: ", i,
                                   " for input b did not unwrap successfully."));
@@ -166,13 +163,10 @@ class AddCtPtOp : public OpKernel {
         op_ctx, bcast.IsValid(),
         InvalidArgument("Invalid broadcast between ", a.shape().DebugString(),
                         " and ", b.shape().DebugString()));
-    auto flat_a = MyBFlat<Variant>(op_ctx, a, bcast.x_reshape(), bcast.x_bcast());
-    auto flat_b = MyBFlat<Variant>(op_ctx, b, bcast.y_reshape(), bcast.y_bcast());
-
-    // Check the inputs have the same shape.
-    OP_REQUIRES(
-        op_ctx, flat_a.size() == flat_b.size(),
-        InvalidArgument("Broadcasted inputs must have the same shape."));
+    auto flat_a = a.flat<Variant>();
+    auto flat_b = b.flat<Variant>();
+    IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape());
+    IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape());
 
     // Allocate the output tensor which is the same size as one of the inputs.
     Tensor* output;
@@ -182,14 +176,14 @@ class AddCtPtOp : public OpKernel {
 
     for (int i = 0; i < flat_output.dimension(0); ++i) {
      SymmetricCtVariant<T> const* ct_a_var =
-          std::move(flat_a(i).get<SymmetricCtVariant<T>>());
+          std::move(flat_a(a_bcaster(i)).get<SymmetricCtVariant<T>>());
       OP_REQUIRES(op_ctx, ct_a_var != nullptr,
                   InvalidArgument("SymmetricCtVariant at flat index: ", i,
                                   " for input a did not unwrap successfully."));
       SymmetricCt const& ct_a = ct_a_var->ct;
 
       PolynomialVariant<T> const* pv_b_var =
-          std::move(flat_b(i).get<PolynomialVariant<T>>());
+          std::move(flat_b(b_bcaster(i)).get<PolynomialVariant<T>>());
       OP_REQUIRES(op_ctx, pv_b_var != nullptr,
                   InvalidArgument("PolynomialVariant at flat index: ", i,
                                   " for input b did not unwrap successfully."));
@@ -229,13 +223,10 @@ class AddPtPtOp : public OpKernel {
         op_ctx, bcast.IsValid(),
         InvalidArgument("Invalid broadcast between ", a.shape().DebugString(),
                         " and ", b.shape().DebugString()));
-    auto flat_a = MyBFlat<Variant>(op_ctx, a, bcast.x_reshape(), bcast.x_bcast());
-    auto flat_b = MyBFlat<Variant>(op_ctx, b, bcast.y_reshape(), bcast.y_bcast());
-
-    // Check the inputs have the same shape.
-    OP_REQUIRES(
-        op_ctx, flat_a.size() == flat_b.size(),
-        InvalidArgument("Broadcasted inputs must have the same shape."));
+    auto flat_a = a.flat<Variant>();
+    auto flat_b = b.flat<Variant>();
+    IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape());
+    IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape());
 
     // Allocate the output tensor which is the same size as one of the inputs.
     Tensor* output;
@@ -245,14 +236,14 @@ class AddPtPtOp : public OpKernel {
 
     for (int i = 0; i < flat_output.dimension(0); ++i) {
       PolynomialVariant<T> const* pv_a_var =
-          std::move(flat_a(i).get<PolynomialVariant<T>>());
+          std::move(flat_a(a_bcaster(i)).get<PolynomialVariant<T>>());
       OP_REQUIRES(op_ctx, pv_a_var != nullptr,
                   InvalidArgument("PolynomialVariant at flat index: ", i,
                                   " for input a did not unwrap successfully."));
       RnsPolynomial const& pt_a = pv_a_var->poly;
 
       PolynomialVariant<T> const* pv_b_var =
-          std::move(flat_b(i).get<PolynomialVariant<T>>());
+          std::move(flat_b(b_bcaster(i)).get<PolynomialVariant<T>>());
       OP_REQUIRES(op_ctx, pv_b_var != nullptr,
                   InvalidArgument("PolynomialVariant at flat index: ", i,
                                   " for input b did not unwrap successfully."));

tf_shell/cc/kernels/mul_kernels.cc

Lines changed: 62 additions & 54 deletions
@@ -64,13 +64,10 @@ class MulCtCtOp : public OpKernel {
         op_ctx, bcast.IsValid(),
         InvalidArgument("Invalid broadcast between ", a.shape().DebugString(),
                         " and ", b.shape().DebugString()));
-    auto flat_a = MyBFlat<Variant>(op_ctx, a, bcast.x_reshape(), bcast.x_bcast());
-    auto flat_b = MyBFlat<Variant>(op_ctx, b, bcast.y_reshape(), bcast.y_bcast());
-
-    // Check the inputs have the same shape.
-    OP_REQUIRES(
-        op_ctx, flat_a.size() == flat_b.size(),
-        InvalidArgument("Broadcasted inputs must have the same shape."));
+    auto flat_a = a.flat<Variant>();
+    auto flat_b = b.flat<Variant>();
+    IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape());
+    IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape());
 
     // Allocate the output tensor which is the same shape as each of the inputs.
     Tensor* output;
@@ -81,14 +78,14 @@ class MulCtCtOp : public OpKernel {
     // Multiply each pair of ciphertexts and store the result in the output.
     for (int i = 0; i < flat_output.dimension(0); ++i) {
       SymmetricCtVariant<T> const* ct_a_var =
-          std::move(flat_a(i).get<SymmetricCtVariant<T>>());
+          std::move(flat_a(a_bcaster(i)).get<SymmetricCtVariant<T>>());
       OP_REQUIRES(op_ctx, ct_a_var != nullptr,
                   InvalidArgument("SymmetricCtVariant at flat index:", i,
                                   " for input a did not unwrap successfully."));
       SymmetricCt const& ct_a = ct_a_var->ct;
 
       SymmetricCtVariant<T> const* ct_b_var =
-          std::move(flat_b(i).get<SymmetricCtVariant<T>>());
+          std::move(flat_b(b_bcaster(i)).get<SymmetricCtVariant<T>>());
       OP_REQUIRES(op_ctx, ct_b_var != nullptr,
                   InvalidArgument("SymmetricCtVariant at flat index:", i,
                                   " for input b did not unwrap successfully."));
@@ -124,46 +121,63 @@ class MulCtPtOp : public OpKernel {
         op_ctx, bcast.IsValid(),
         InvalidArgument("Invalid broadcast between ", a.shape().DebugString(),
                         " and ", b.shape().DebugString()));
-    auto flat_a = MyBFlat<Variant>(op_ctx, a, bcast.x_reshape(), bcast.x_bcast());
-    auto flat_b = MyBFlat<Variant>(op_ctx, b, bcast.y_reshape(), bcast.y_bcast());
-
-    // Check the inputs have the same shape.
-    OP_REQUIRES(
-        op_ctx, flat_a.size() == flat_b.size(),
-        InvalidArgument("Broadcasted inputs must have the same shape."));
+    auto flat_a = a.flat<Variant>();
+    auto flat_b = b.flat<Variant>();
+    IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape());
+    IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape());
 
     // Allocate the output tensor which is the same shape as each of the inputs.
     Tensor* output;
     TensorShape output_shape = BCast::ToShape(bcast.output_shape());
     OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output));
     auto flat_output = output->flat<Variant>();
 
-    for (int i = 0; i < flat_output.dimension(0); ++i) {
-      SymmetricCtVariant<T> const* ct_a_var =
-          std::move(flat_a(i).get<SymmetricCtVariant<T>>());
-      OP_REQUIRES(op_ctx, ct_a_var != nullptr,
-                  InvalidArgument("SymmetricCtVariant at flat index:", i,
-                                  " for input a did not unwrap successfully."));
-      SymmetricCt const& ct_a = ct_a_var->ct;
+    // Recover num_slots from first ciphertext.
+    SymmetricCtVariant<T> const* ct_var =
+        std::move(flat_a(0).get<SymmetricCtVariant<T>>());
+    OP_REQUIRES(
+        op_ctx, ct_var != nullptr,
+        InvalidArgument("SymmetricCtVariant a did not unwrap successfully."));
+    SymmetricCt const& ct = ct_var->ct;
+    int num_slots = 1 << ct.LogN();
+    int num_components = ct.NumModuli();
 
-      PolynomialVariant<T> const* pv_b_var =
-          std::move(flat_b(i).get<PolynomialVariant<T>>());
-      OP_REQUIRES(op_ctx, pv_b_var != nullptr,
-                  InvalidArgument("PolynomialVariant at flat index:", i,
-                                  " for input b did not unwrap successfully."));
-      RnsPolynomial const& pt_b = pv_b_var->poly;
+    auto mul_in_range = [&](int start, int end) {
+      for (int i = start; i < end; ++i) {
+        SymmetricCtVariant<T> const* ct_a_var =
+            std::move(flat_a(a_bcaster(i)).get<SymmetricCtVariant<T>>());
+        OP_REQUIRES(
+            op_ctx, ct_a_var != nullptr,
+            InvalidArgument("SymmetricCtVariant at flat index:", i,
+                            " for input a did not unwrap successfully."));
+        SymmetricCt const& ct_a = ct_a_var->ct;
 
-      OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx,
-                        ct_a * pt_b);  // shell absorb operation
+        PolynomialVariant<T> const* pv_b_var =
+            std::move(flat_b(b_bcaster(i)).get<PolynomialVariant<T>>());
+        OP_REQUIRES(
+            op_ctx, pv_b_var != nullptr,
+            InvalidArgument("PolynomialVariant at flat index:", i,
+                            " for input b did not unwrap successfully."));
+        RnsPolynomial const& pt_b = pv_b_var->poly;
 
-      SymmetricCtVariant ct_c_var(std::move(ct_c));
-      flat_output(i) = std::move(ct_c_var);
-    }
+        OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx,
+                          ct_a * pt_b);  // shell absorb operation
+
+        SymmetricCtVariant ct_c_var(std::move(ct_c));
+        flat_output(i) = std::move(ct_c_var);
+      }
+    };
+
+    auto thread_pool =
+        op_ctx->device()->tensorflow_cpu_worker_threads()->workers;
+    int const cost_per_mul = 30 * num_slots * num_components;
+    thread_pool->ParallelFor(flat_output.dimension(0), cost_per_mul,
+                             mul_in_range);
   }
 };
 
-// This Op can multiply either a shell ciphertext or a plaintext polynomial by a
-// plaintext scalar, depending on the class template.
+// This Op can multiply either a shell ciphertext or a plaintext polynomial by
+// a plaintext scalar, depending on the class template.
 template <typename T, typename PtT, typename CtOrPolyVariant>
 class MulShellTfScalarOp : public OpKernel {
  private:
@@ -194,12 +208,8 @@ class MulShellTfScalarOp : public OpKernel {
         InvalidArgument("Invalid broadcast between ", a.shape().DebugString(),
                         " and ", b.shape().DebugString()));
     auto flat_a = a.flat<Variant>();  // a is not broadcasted, just b.
-    auto flat_b = MyBFlat<PtT>(op_ctx, b, bcast.y_reshape(), bcast.y_bcast());
-
-    // Check the inputs have the same shape.
-    OP_REQUIRES(
-        op_ctx, flat_a.size() == flat_b.size(),
-        InvalidArgument("Broadcasted inputs must have the same shape."));
+    auto flat_b = b.flat<PtT>();
+    IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape());
 
     // Allocate the output tensor which is the same shape as the first input.
     Tensor* output;
@@ -211,7 +221,7 @@ class MulShellTfScalarOp : public OpKernel {
       // First encode the scalar b
       // TDOO(jchoncholas): encode all scalars at once beforehand.
       T wrapped_b;
-      EncodeScalar(op_ctx, flat_b(i), encoder, &wrapped_b);
+      EncodeScalar(op_ctx, flat_b(b_bcaster(i)), encoder, &wrapped_b);
 
       CtOrPolyVariant const* ct_or_pt_var =
           std::move(flat_a(i).get<CtOrPolyVariant>());
@@ -241,7 +251,7 @@ class MulShellTfScalarOp : public OpKernel {
     }
   }
 
-  private:
+ private:
   void EncodeScalar(OpKernelContext* op_ctx, PtT const& val, Encoder const* encoder, T* wrapped_val) {
     if constexpr (std::is_signed<PtT>::value) {
       // SHELL is built on the assumption that the plaintext type (in this
@@ -293,30 +303,28 @@ class MulPtPtOp : public OpKernel {
         op_ctx, bcast.IsValid(),
         InvalidArgument("Invalid broadcast between ", a.shape().DebugString(),
                         " and ", b.shape().DebugString()));
-    auto flat_a = MyBFlat<Variant>(op_ctx, a, bcast.x_reshape(), bcast.x_bcast());
-    auto flat_b = MyBFlat<Variant>(op_ctx, b, bcast.y_reshape(), bcast.y_bcast());
-
-    // Check the inputs have the same shape.
-    OP_REQUIRES(
-        op_ctx, flat_a.size() == flat_b.size(),
-        InvalidArgument("Broadcasted inputs must have the same shape."));
+    auto flat_a = a.flat<Variant>();
+    auto flat_b = b.flat<Variant>();
+    IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape());
+    IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape());
 
-    // Allocate the output tensor which is the same shape as each of the inputs.
+    // Allocate the output tensor which is the same shape as each of the
+    // inputs.
     Tensor* output;
     TensorShape output_shape = BCast::ToShape(bcast.output_shape());
     OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output));
     auto flat_output = output->flat<Variant>();
 
     for (int i = 0; i < flat_output.dimension(0); ++i) {
       PolynomialVariant<T> const* pv_a_var =
-          std::move(flat_a(i).get<PolynomialVariant<T>>());
+          std::move(flat_a(a_bcaster(i)).get<PolynomialVariant<T>>());
       OP_REQUIRES(op_ctx, pv_a_var != nullptr,
                   InvalidArgument("PolynomialVariant at flat index:", i,
                                   " for input a did not unwrap successfully."));
       RnsPolynomial const& pt_a = pv_a_var->poly;
 
       PolynomialVariant<T> const* pv_b_var =
-          std::move(flat_b(i).get<PolynomialVariant<T>>());
+          std::move(flat_b(b_bcaster(i)).get<PolynomialVariant<T>>());
       OP_REQUIRES(op_ctx, pv_b_var != nullptr,
                   InvalidArgument("PolynomialVariant at flat index:", i,
                                   " for input b did not unwrap successfully."));
