@@ -126,6 +126,36 @@ static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t

// Jiterator functions are guarded behind this macro
#ifdef USE_JITERATOR
+
+ namespace {
+
+ template <typename Tuple, std::size_t... I>
+ constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> seq) {
+   constexpr auto size = seq.size();
+   (void)t; // silence the unused-parameter warning when the tuple is empty
+   return std::array<void*, size>{static_cast<void*>(&std::get<I>(t))...};
+ }
+
+ // Helper function to convert a tuple into a std::array<void*, N>
+ // for passing the arguments to a CUDA kernel.
+ // NOTE: the tuple is captured by reference, so the pointers in the
+ // returned array are only valid while the tuple is alive.
+ template <typename... Args>
+ constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
+   constexpr auto tuple_size = sizeof...(Args);
+   return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
+ }
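+
+ // Illustrative usage of tuple_to_array (hypothetical, not part of this
+ // change); the returned array must not outlive the tuple it points into:
+ //   std::tuple<int, float> extra{1, 2.0f};
+ //   auto arr = tuple_to_array(extra);  // arr[0] == &std::get<0>(extra)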
+
+ // Helper function to return a c10::SmallVector<std::string>
+ // corresponding to the types of the arguments in the parameter pack.
+ template <typename... Args>
+ c10::SmallVector<std::string> get_extra_args_typenames() {
+   return {at::cuda::jit::typeName<Args>()...};
+ }
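+ // For example (illustrative), get_extra_args_typenames<float, int>()
+ // is expected to yield {"float", "int"}.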
+
+ } // namespace
+
template<char const *name,
         typename result_type,
         typename f_inputs_type,
@@ -134,10 +164,13 @@ template<char const *name,
         typename inp_calc_t,
         typename out_calc_t,
         typename loader_t,
-        typename storer_t>
+        typename storer_t,
+        typename... Args>
static inline void launch_jitted_unrolled_kernel(
    DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data,
-   inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s, bool contiguous, at::opmath_type<f_inputs_type> scalar_val) {
+   inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s, bool contiguous,
+   at::opmath_type<f_inputs_type> scalar_val,
+   std::tuple<Args...> extra_args) {

  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  const int64_t grid = (N + block_work_size() - 1) / block_work_size();
@@ -157,23 +190,32 @@ static inline void launch_jitted_unrolled_kernel(
      std::string f_inputs_type_str = at::cuda::jit::typeName<f_inputs_type>();
      std::string compute_type_str = at::cuda::jit::typeName<at::opmath_type<f_inputs_type>>();
      std::string result_type_str = at::cuda::jit::typeName<result_type>();
+     c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames<Args...>();
      auto code = at::cuda::jit::generate_code(nTensors, f, string_name,
                                               f_inputs_type_str, compute_type_str, result_type_str,
-                                              contiguous, dynamic_casting, scalar_pos);
+                                              contiguous, dynamic_casting, scalar_pos, extra_args_types);
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
    }
  }

- // packs args
- std::array<void*, 7> args = {
-   (void*)&N,
-   (void*)&data,
-   (void*)&ic,
-   (void*)&oc,
-   (void*)&l,
-   (void*)&s,
-   (void*)&scalar_val
- };
+ // pack args for kernel launch
+ constexpr int kernel_args = 7;
+ // size of `extra_args` is known at compile-time
+ constexpr auto extra_args_size = sizeof...(Args);
+ void* args[kernel_args + extra_args_size];
+ args[0] = static_cast<void*>(&N);
+ args[1] = static_cast<void*>(&data);
+ args[2] = static_cast<void*>(&ic);
+ args[3] = static_cast<void*>(&oc);
+ args[4] = static_cast<void*>(&l);
+ args[5] = static_cast<void*>(&s);
+ args[6] = static_cast<void*>(&scalar_val);
+
+ auto extra_args_array = tuple_to_array(extra_args);
+ for (const auto i : c10::irange(extra_args_size)) {
+   // since 7 slots are already filled in `args`
+   args[i + 7] = extra_args_array[i];
+ }
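+ // Resulting layout, for illustration only: with std::tuple<int, float> as
+ // `extra_args`, args holds {&N, &data, &ic, &oc, &l, &s, &scalar_val,
+ // &std::get<0>(extra_args), &std::get<1>(extra_args)}.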

  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -185,9 +227,9 @@ template<
          typename f_inputs_type,
          int arity,
          at::cuda::jit::BinaryFuncVariant scalar_pos,
-         typename array_t>
+         typename array_t, typename... Args>
static inline void launch_jitted_vectorized_kernel(DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data,
- at::opmath_type<f_inputs_type> scalar_val) {
+ at::opmath_type<f_inputs_type> scalar_val, std::tuple<Args...> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  const int64_t grid = (N + block_work_size() - 1) / block_work_size();
  const int vec_size = memory::jitted_can_vectorize_up_to<result_type, f_inputs_type, arity>(data);
@@ -225,10 +267,12 @@ at::opmath_type<f_inputs_type> scalar_val) {
      std::string f_inputs_type_str = at::cuda::jit::typeName<f_inputs_type>();
      std::string compute_type_str = at::cuda::jit::typeName<at::opmath_type<f_inputs_type>>();
      std::string result_type_str = at::cuda::jit::typeName<result_type>();
+     c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames<Args...>();
      auto code = at::cuda::jit::generate_code(nTensors, f, string_name,
                                               f_inputs_type_str, compute_type_str, result_type_str,
                                               /*contiguous=*/true, /*dynamic_casting=*/false,
                                               scalar_pos,
+                                              extra_args_types,
                                               vectorized, vec_size);
      std::string kernel_name = vectorized ? string_name + "_vectorized" + std::to_string(vec_size) : string_name;

@@ -237,16 +281,22 @@ at::opmath_type<f_inputs_type> scalar_val) {
    }
  }

+ // size of `extra_args` is known at compile-time
+ constexpr auto extra_args_size = sizeof...(Args);
+ auto extra_args_array = tuple_to_array(extra_args);
+
  if (vectorized) {
-   std::array<void*, 7> args = {
-     (void*)&N,
-     (void*)&data,
-     (void*)&scalar_val,
-     nullptr,
-     nullptr,
-     nullptr,
-     nullptr
-   };
+   // pack args for kernel launch
+   constexpr int kernel_args = 3;
+   void* args[kernel_args + extra_args_size];
+   args[0] = static_cast<void*>(&N);
+   args[1] = static_cast<void*>(&data);
+   args[2] = static_cast<void*>(&scalar_val);
+
+   for (const auto i : c10::irange(extra_args_size)) {
+     // since 3 slots are already filled in `args`
+     args[i + 3] = extra_args_array[i];
+   }
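+   // Resulting layout, for illustration only: with std::tuple<int> as
+   // `extra_args`, args holds {&N, &data, &scalar_val, &std::get<0>(extra_args)}.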

    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -256,24 +306,29 @@ at::opmath_type<f_inputs_type> scalar_val) {
    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();

-   std::array<void*, 7> args = {
-     (void*)&N,
-     (void*)&data,
-     (void*)&ic,
-     (void*)&oc,
-     (void*)&l,
-     (void*)&s,
-     (void*)&scalar_val
-   };
-
+   // pack args for kernel launch
+   constexpr int kernel_args = 7;
+   void* args[kernel_args + extra_args_size];
+   args[0] = static_cast<void*>(&N);
+   args[1] = static_cast<void*>(&data);
+   args[2] = static_cast<void*>(&ic);
+   args[3] = static_cast<void*>(&oc);
+   args[4] = static_cast<void*>(&l);
+   args[5] = static_cast<void*>(&s);
+   args[6] = static_cast<void*>(&scalar_val);
+
+   for (const auto i : c10::irange(extra_args_size)) {
+     // since 7 slots are already filled in `args`
+     args[i + 7] = extra_args_array[i];
+   }
    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

template <char const *name, typename result_type, typename compute_type, int arity,
-         at::cuda::jit::BinaryFuncVariant scalar_pos=at::cuda::jit::BinaryFuncVariant::NoScalar>
- void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, const bool dynamic_casting, compute_type scalar_val = 0) {
+         at::cuda::jit::BinaryFuncVariant scalar_pos=at::cuda::jit::BinaryFuncVariant::NoScalar, typename... Args>
+ void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, const bool dynamic_casting, compute_type scalar_val, std::tuple<Args...> extra_args) {
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
@@ -299,7 +354,7 @@ void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, cons
  if (contiguous) {
    // Case 1: no dynamic casting and contiguous
    launch_jitted_vectorized_kernel<name, result_type, compute_type, arity, scalar_pos>(
-     iter.device().index(), numel, f, data, scalar_val);
+     iter.device().index(), numel, f, data, scalar_val, extra_args);
    return;
  }

@@ -310,7 +365,7 @@ void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, cons
    auto storer = memory::StoreWithoutCast();
    launch_jitted_unrolled_kernel<name, result_type, compute_type, scalar_pos>(
      iter.device().index(), numel, f, data, input_offset_calculator,
-     output_offset_calculator, loader, storer, contiguous, scalar_val);
+     output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args);
    return;
  }

@@ -333,7 +388,7 @@ void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, cons
    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    launch_jitted_unrolled_kernel<name, result_type, compute_type, scalar_pos>(
      iter.device().index(), numel, f, data, input_offset_calculator,
-     output_offset_calculator, loader, storer, contiguous, scalar_val);
+     output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args);
    return;
  }

@@ -342,7 +397,7 @@ void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, cons
  auto output_offset_calculator = make_output_offset_calculator(iter);
  launch_jitted_unrolled_kernel<name, result_type, compute_type, scalar_pos>(
    iter.device().index(), numel, f, data, input_offset_calculator,
-   output_offset_calculator, loader, storer, contiguous, scalar_val);
+   output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args);
}
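+ // Illustrative call with the new `extra_args` parameter (kernel name and
+ // argument values here are hypothetical):
+ //   jitted_gpu_kernel_impl<kName, float, float, /*arity=*/1>(
+ //       iter, f, /*dynamic_casting=*/false, /*scalar_val=*/0.f,
+ //       std::make_tuple(alpha));  // `alpha` is passed through to the kernel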
#endif // USE_JITERATOR