
Commit f269f99

kshitij12345 authored and pytorchmergebot committed
[jiterator] polygamma (pytorch#71162)
Summary: Reference: pytorch#69463

TODO:
* [x] Add a note on how to capture a value and its limitations.

Pull Request resolved: pytorch#71162
Reviewed By: ngimel
Differential Revision: D33697346
Pulled By: mruberry
fbshipit-source-id: 0308a6c12cf4b488ab26bdae14291c1f6dbba9ab
(cherry picked from commit 0d3f8c5)
1 parent c7c864b commit f269f99

File tree: 8 files changed, +252 −86 lines


aten/src/ATen/native/Math.h (+4 −3)

@@ -409,10 +409,11 @@ static inline float calc_digamma(float x) {
 }
 
 template <typename scalar_t, bool is_cuda=false>
-static inline C10_HOST_DEVICE scalar_t calc_polygamma(int n, scalar_t x) {
+static inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) {
   // already blocked if n <= 1
-  return ((n % 2) ? 1.0 : -1.0) *
-      ::exp(::lgamma(static_cast<scalar_t>(n) + 1.0)) *
+  const auto one = scalar_t{1};
+  return ((n % 2) ? one : -one) *
+      ::exp(::lgamma(static_cast<scalar_t>(n) + one)) *
       zeta<scalar_t, is_cuda>(static_cast<scalar_t>(n + 1), x);
 }
 
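The argument order flips so that the iterator-provided input `x` comes first and the captured order `n` trails it, matching the extra-args rule documented in Loops.cuh below. Mathematically, for n >= 1 the function computes psi^(n)(x) = (-1)^(n+1) * n! * zeta(n+1, x), with n! obtained as exp(lgamma(n+1)) and the sign folded into the `(n % 2)` ternary (ATen routes n <= 1 to digamma/trigamma, hence the "already blocked" comment). A self-contained sketch of the patched function; `toy_zeta` is a deliberately naive stand-in for ATen's `zeta<scalar_t, is_cuda>` and is illustrative only:

#include <cmath>
#include <cstdio>

// Naive Hurwitz zeta via direct summation -- slow, and only for x > 0;
// a stand-in for ATen's zeta<scalar_t, is_cuda>.
template <typename scalar_t>
scalar_t toy_zeta(scalar_t s, scalar_t x) {
  scalar_t sum = 0;
  for (int k = 0; k < 100000; ++k) {
    sum += std::pow(x + static_cast<scalar_t>(k), -s);
  }
  return sum;
}

// Mirrors the patched signature: iterator-provided `x` first, captured `n` second.
template <typename scalar_t>
scalar_t calc_polygamma(scalar_t x, int n) {
  const auto one = scalar_t{1};
  return ((n % 2) ? one : -one) *
      std::exp(std::lgamma(static_cast<scalar_t>(n) + one)) *
      toy_zeta<scalar_t>(static_cast<scalar_t>(n + 1), x);
}

int main() {
  // psi^(2)(1) = -2 * zeta(3) ~= -2.404114
  std::printf("%f\n", calc_polygamma(1.0, 2));
}
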
aten/src/ATen/native/cpu/UnaryOpsKernel.cpp (+1 −1)

@@ -412,7 +412,7 @@ static void polygamma_kernel(TensorIteratorBase& iter, int64_t n) {
   } else {
     AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "polygamma", [&]() {
       cpu_kernel(
-          iter, [=](scalar_t a) -> scalar_t { return calc_polygamma(n, a); });
+          iter, [=](scalar_t a) -> scalar_t { return calc_polygamma(a, n); });
     });
   }
 }
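
Note the contrast with the CUDA changes below: on the CPU path the runtime order `n` needs no extra plumbing, because the kernel lambda simply captures it by value (`[=]`). A minimal stand-in sketch of that pattern; `toy_cpu_kernel` is illustrative, not ATen's `cpu_kernel`:

#include <cstddef>
#include <cstdio>
#include <vector>

// Applies f elementwise -- a toy stand-in for ATen's cpu_kernel loop.
template <typename scalar_t, typename func_t>
void toy_cpu_kernel(std::vector<scalar_t>& out,
                    const std::vector<scalar_t>& in, func_t f) {
  for (std::size_t i = 0; i < in.size(); ++i) {
    out[i] = f(in[i]);
  }
}

int main() {
  int n = 2;  // runtime state, captured by value in the lambda below
  std::vector<double> in{1.0, 2.0}, out(2);
  toy_cpu_kernel(out, in, [=](double a) { return a * n; });
  std::printf("%f %f\n", out[0], out[1]);  // 2.0 4.0
}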

aten/src/ATen/native/cuda/CUDALoops.cuh (+95 −40)

@@ -126,6 +126,36 @@ static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t
 
 // Jiterator functions are guarded behind this macro
 #ifdef USE_JITERATOR
+
+namespace {
+
+template <typename Tuple, std::size_t... I>
+constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> seq) {
+  constexpr auto size = seq.size();
+  (void)t; // avoids an unused-parameter warning when the tuple is empty
+  return std::array<void*, size>{static_cast<void*>(&std::get<I>(t))...};
+}
+
+// Helper function to convert a tuple to std::array<void*, N>
+// for passing the arguments to a CUDA kernel.
+// NOTE: We capture the tuple by reference,
+// so the pointers in the returned array are only valid
+// while the tuple is alive.
+template <typename ...Args>
+constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
+  constexpr auto tuple_size = sizeof...(Args);
+  return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
+}
+
+// Helper function returning a vector of strings
+// naming the types of the arguments in a parameter pack.
+template <typename... Args>
+c10::SmallVector<std::string> get_extra_args_typenames() {
+  return {at::cuda::jit::typeName<Args>()...};
+}
+
+} // namespace
+
 template<char const *name,
          typename result_type,
          typename f_inputs_type,
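
The two helpers introduced in the hunk above are plain C++14 metaprogramming and are easy to verify in isolation. A self-contained sketch of the tuple-to-`void*`-array technique, outside of ATen (same shape as the diff's helpers, minus the c10/jit pieces):

#include <array>
#include <cstddef>
#include <cstdio>
#include <tuple>
#include <utility>

template <typename Tuple, std::size_t... I>
constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> seq) {
  constexpr auto size = seq.size();
  (void)t;  // avoids an unused-parameter warning when the tuple is empty
  return std::array<void*, size>{static_cast<void*>(&std::get<I>(t))...};
}

template <typename... Args>
constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
  return tuple_to_array_helper(
      extra_args, std::make_index_sequence<sizeof...(Args)>{});
}

int main() {
  auto extra = std::make_tuple(3, 2.5f);  // e.g. a polygamma order and a float
  auto ptrs = tuple_to_array(extra);      // valid only while `extra` is alive
  std::printf("%d %f\n", *static_cast<int*>(ptrs[0]),
              *static_cast<float*>(ptrs[1]));
}
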
@@ -134,10 +164,13 @@ template<char const *name,
          typename inp_calc_t,
          typename out_calc_t,
          typename loader_t,
-         typename storer_t>
+         typename storer_t,
+         typename ... Args>
 static inline void launch_jitted_unrolled_kernel(
   DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data,
-  inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s, bool contiguous, at::opmath_type<f_inputs_type> scalar_val) {
+  inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s, bool contiguous,
+  at::opmath_type<f_inputs_type> scalar_val,
+  std::tuple<Args...> extra_args) {
 
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
   const int64_t grid = (N + block_work_size() - 1) / block_work_size();
@@ -157,23 +190,32 @@ static inline void launch_jitted_unrolled_kernel(
       std::string f_inputs_type_str = at::cuda::jit::typeName<f_inputs_type>();
       std::string compute_type_str = at::cuda::jit::typeName<at::opmath_type<f_inputs_type>>();
       std::string result_type_str = at::cuda::jit::typeName<result_type>();
+      c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames<Args...>();
       auto code = at::cuda::jit::generate_code(nTensors, f, string_name,
                                                f_inputs_type_str, compute_type_str, result_type_str,
-                                               contiguous, dynamic_casting, scalar_pos);
+                                               contiguous, dynamic_casting, scalar_pos, extra_args_types);
       *fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
     }
   }
 
-  // packs args
-  std::array<void*, 7> args = {
-    (void*)&N,
-    (void*)&data,
-    (void*)&ic,
-    (void*)&oc,
-    (void*)&l,
-    (void*)&s,
-    (void*)&scalar_val
-  };
+  // pack args for kernel launch
+  constexpr int kernel_args = 7;
+  // size of `extra_args` is known at compile-time
+  constexpr auto extra_args_size = sizeof...(Args);
+  void* args[kernel_args + extra_args_size];
+  args[0] = static_cast<void*>(&N);
+  args[1] = static_cast<void*>(&data);
+  args[2] = static_cast<void*>(&ic);
+  args[3] = static_cast<void*>(&oc);
+  args[4] = static_cast<void*>(&l);
+  args[5] = static_cast<void*>(&s);
+  args[6] = static_cast<void*>(&scalar_val);
+
+  auto extra_args_array = tuple_to_array(extra_args);
+  for (const auto i : c10::irange(extra_args_size)) {
+    // since 7 slots are already filled in `args`
+    args[i + 7] = extra_args_array[i];
+  }
 
   at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
   C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -185,9 +227,9 @@ template<
          typename f_inputs_type,
          int arity,
          at::cuda::jit::BinaryFuncVariant scalar_pos,
-         typename array_t>
+         typename array_t, typename ... Args>
 static inline void launch_jitted_vectorized_kernel(DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data,
-at::opmath_type<f_inputs_type> scalar_val) {
+at::opmath_type<f_inputs_type> scalar_val, std::tuple<Args...> extra_args) {
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
   const int64_t grid = (N + block_work_size() - 1) / block_work_size();
   const int vec_size = memory::jitted_can_vectorize_up_to<result_type, f_inputs_type, arity>(data);
@@ -225,10 +267,12 @@ at::opmath_type<f_inputs_type> scalar_val) {
       std::string f_inputs_type_str = at::cuda::jit::typeName<f_inputs_type>();
       std::string compute_type_str = at::cuda::jit::typeName<at::opmath_type<f_inputs_type>>();
       std::string result_type_str = at::cuda::jit::typeName<result_type>();
+      c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames<Args...>();
       auto code = at::cuda::jit::generate_code(nTensors, f, string_name,
                                                f_inputs_type_str, compute_type_str, result_type_str,
                                                /*contiguous=*/true, /*dynamic_casting=*/false,
                                                scalar_pos,
+                                               extra_args_types,
                                                vectorized, vec_size);
       std::string kernel_name = vectorized ? string_name + "_vectorized" + std::to_string(vec_size) : string_name;
 
@@ -237,16 +281,22 @@ at::opmath_type<f_inputs_type> scalar_val) {
     }
   }
 
+  // size of `extra_args` is known at compile-time
+  constexpr auto extra_args_size = sizeof...(Args);
+  auto extra_args_array = tuple_to_array(extra_args);
+
   if (vectorized) {
-    std::array<void*, 7> args = {
-      (void*)&N,
-      (void*)&data,
-      (void*)&scalar_val,
-      nullptr,
-      nullptr,
-      nullptr,
-      nullptr
-    };
+    // pack args for kernel launch
+    constexpr int kernel_args = 3;
+    void* args[kernel_args + extra_args_size];
+    args[0] = static_cast<void*>(&N);
+    args[1] = static_cast<void*>(&data);
+    args[2] = static_cast<void*>(&scalar_val);
+
+    for (const auto i : c10::irange(extra_args_size)) {
+      // since 3 slots are already filled in `args`
+      args[i + 3] = extra_args_array[i];
+    }
 
     at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -256,24 +306,29 @@ at::opmath_type<f_inputs_type> scalar_val) {
     auto l = memory::LoadWithoutCast();
     auto s = memory::StoreWithoutCast();
 
-    std::array<void*, 7> args = {
-      (void*)&N,
-      (void*)&data,
-      (void*)&ic,
-      (void*)&oc,
-      (void*)&l,
-      (void*)&s,
-      (void*)&scalar_val
-    };
-
+    // pack args for kernel launch
+    constexpr int kernel_args = 7;
+    void* args[kernel_args + extra_args_size];
+    args[0] = static_cast<void*>(&N);
+    args[1] = static_cast<void*>(&data);
+    args[2] = static_cast<void*>(&ic);
+    args[3] = static_cast<void*>(&oc);
+    args[4] = static_cast<void*>(&l);
+    args[5] = static_cast<void*>(&s);
+    args[6] = static_cast<void*>(&scalar_val);
+
+    for (const auto i : c10::irange(extra_args_size)) {
+      // since 7 slots are already filled in `args`
+      args[i + 7] = extra_args_array[i];
+    }
    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
 }
 
 template <char const *name, typename result_type, typename compute_type, int arity,
-          at::cuda::jit::BinaryFuncVariant scalar_pos=at::cuda::jit::BinaryFuncVariant::NoScalar>
-void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, const bool dynamic_casting, compute_type scalar_val = 0) {
+          at::cuda::jit::BinaryFuncVariant scalar_pos=at::cuda::jit::BinaryFuncVariant::NoScalar, typename ... Args>
+void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, const bool dynamic_casting, compute_type scalar_val, std::tuple<Args...> extra_args) {
   TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
   TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
   TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
@@ -299,7 +354,7 @@ void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, cons
   if (contiguous) {
     // Case 1: no dynamic casting and contiguous
     launch_jitted_vectorized_kernel<name, result_type, compute_type, arity, scalar_pos>(
-      iter.device().index(), numel, f, data, scalar_val);
+      iter.device().index(), numel, f, data, scalar_val, extra_args);
     return;
   }
 
@@ -310,7 +365,7 @@ void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, cons
   auto storer = memory::StoreWithoutCast();
   launch_jitted_unrolled_kernel<name, result_type, compute_type, scalar_pos>(
     iter.device().index(), numel, f, data, input_offset_calculator,
-    output_offset_calculator, loader, storer, contiguous, scalar_val);
+    output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args);
   return;
 }
 
@@ -333,7 +388,7 @@ void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, cons
   auto output_offset_calculator = TrivialOffsetCalculator<1>();
   launch_jitted_unrolled_kernel<name, result_type, compute_type, scalar_pos>(
     iter.device().index(), numel, f, data, input_offset_calculator,
-    output_offset_calculator, loader, storer, contiguous, scalar_val);
+    output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args);
   return;
 }
 
@@ -342,7 +397,7 @@ void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, cons
   auto output_offset_calculator = make_output_offset_calculator(iter);
   launch_jitted_unrolled_kernel<name, result_type, compute_type, scalar_pos>(
     iter.device().index(), numel, f, data, input_offset_calculator,
-    output_offset_calculator, loader, storer, contiguous, scalar_val);
+    output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args);
 }
 #endif // USE_JITERATOR
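
Why reduce everything, including the captured extras, to `void*` slots? Because the CUDA launch ABI takes kernel parameters as an array of `void*`, one pointer per parameter in signature order, which is what `at::cuda::jit::launch_jitted_pwise_function` ultimately hands to the driver. A minimal CUDA sketch of that convention, independent of the jiterator:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale(int n, float a, float* x) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= a;
}

int main() {
  int n = 4;
  float a = 2.0f;
  float* x = nullptr;
  cudaMallocManaged(&x, n * sizeof(float));
  for (int i = 0; i < n; ++i) x[i] = 1.0f;

  // One void* per kernel parameter, in signature order -- the same
  // convention the jiterator follows when it appends `extra_args`.
  void* args[] = {&n, &a, &x};
  cudaLaunchKernel((const void*)scale, dim3(1), dim3(32), args, 0, nullptr);
  cudaDeviceSynchronize();
  std::printf("%f\n", x[0]);  // 2.0
  cudaFree(x);
}
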

aten/src/ATen/native/cuda/Loops.cuh (+66 −18)

@@ -107,16 +107,43 @@ These improvements will likely come soon.
 For examples of how to use the jiterator see the i1 and gcd kernel
 implementations, which pass jittable strings implementing their
 operations instead of the typical CUDA functors.
+
+To pass a runtime argument (similar to lambda captures in non-JIT kernels),
+pass the additional arguments to `jitted_gpu_kernel` by value.
+Currently only primitive C++ types used for computation are valid.
+The order of these extra arguments should be the same as the order in which
+they appear in the kernel's function signature (see polygamma for an example).
+
+NOTE: One big restriction is that these arguments must come after the
+arguments provided by TensorIterator. E.g., while capturing `n`, where
+`scalar_t x` and `scalar_t y` are provided by TensorIterator,
+* foo(scalar_t x, scalar_t y, int n) works!
+* foo(int n, scalar_t x, scalar_t y) doesn't work
+* foo(scalar_t x, int n, scalar_t y) doesn't work
+
 */
 
 // Entrypoint for jitted GPU kernels.
 // Only handles elementwise unary and binary kernels with a
 // common dtype and a single output.
 // NOTE: this assumes the op's iterator has a common_dtype.
-template <char const *name, typename return_type, typename f_inputs_type, int arity>
-void jitted_gpu_kernel(TensorIteratorBase& iter, const std::string& f,
-               at::cuda::jit::BinaryFuncVariant scalar_pos=at::cuda::jit::BinaryFuncVariant::NoScalar,
-               at::opmath_type<f_inputs_type> scalar_val=0) {
+// NOTE: We use std::tuple instead of a parameter pack
+//       for `extra_args` due to the following bug
+//       on older versions of clang:
+//       https://bugs.llvm.org/show_bug.cgi?id=23029
+template <
+    char const* name,
+    typename return_type,
+    typename f_inputs_type,
+    int arity,
+    typename... Args>
+void jitted_gpu_kernel(
+    TensorIteratorBase& iter,
+    const std::string& f,
+    at::cuda::jit::BinaryFuncVariant scalar_pos =
+        at::cuda::jit::BinaryFuncVariant::NoScalar,
+    at::opmath_type<f_inputs_type> scalar_val = 0,
+    std::tuple<Args...> extra_args = std::make_tuple()) {
   // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel
   // Maybe it could be refactored?
   static_assert((!std::is_same<return_type, c10::complex<double>>::value &&
@@ -137,7 +164,8 @@ at::opmath_type<f_inputs_type> scalar_val=0) {
 
   if (!iter.can_use_32bit_indexing()) {
     for (auto& sub_iter : iter.with_32bit_indexing()) {
-      jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(sub_iter, f, scalar_pos, scalar_val);
+      jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(
+          sub_iter, f, scalar_pos, scalar_val, extra_args);
     }
 
     return;
@@ -170,25 +198,45 @@ at::opmath_type<f_inputs_type> scalar_val=0) {
                 "Encountered an unsupported dtype ", dtypei, "!");
   }
   if (scalar_pos == at::cuda::jit::BinaryFuncVariant::NoScalar) {
-    jitted_gpu_kernel_impl</*name*/ name,
-                           /*return_type=*/ return_type,
-                           /*f_inputs_type=*/ f_inputs_type,
-                           arity, at::cuda::jit::BinaryFuncVariant::NoScalar>(iter, f, needs_dynamic_casting);
+    // NOTE: With `scalar_pos=NoScalar`, `scalar_val` is not used
+    // for computation in the generated code and hence we pass a dummy
+    // value of `0`.
+    jitted_gpu_kernel_impl<
+        /*name*/ name,
+        /*return_type=*/return_type,
+        /*f_inputs_type=*/f_inputs_type,
+        arity,
+        at::cuda::jit::BinaryFuncVariant::NoScalar>(
+        iter, f, needs_dynamic_casting, /*scalar_val=*/0, extra_args);
   } else if (scalar_pos == at::cuda::jit::BinaryFuncVariant::RhsScalar) {
-    jitted_gpu_kernel_impl</*name*/ name,
-                           /*return_type=*/ return_type,
-                           /*f_inputs_type=*/ f_inputs_type,
-                           arity, at::cuda::jit::BinaryFuncVariant::RhsScalar>(iter, f, needs_dynamic_casting, scalar_val);
+    jitted_gpu_kernel_impl<
+        /*name*/ name,
+        /*return_type=*/return_type,
+        /*f_inputs_type=*/f_inputs_type,
+        arity,
+        at::cuda::jit::BinaryFuncVariant::RhsScalar>(
+        iter,
+        f,
+        needs_dynamic_casting,
+        scalar_val,
+        extra_args);
 
   } else {
-    jitted_gpu_kernel_impl</*name*/ name,
-                           /*return_type=*/ return_type,
-                           /*f_inputs_type=*/ f_inputs_type,
-                           arity, at::cuda::jit::BinaryFuncVariant::LhsScalar>(iter, f, needs_dynamic_casting, scalar_val);
-
+    jitted_gpu_kernel_impl<
+        /*name*/ name,
+        /*return_type=*/return_type,
+        /*f_inputs_type=*/f_inputs_type,
+        arity,
+        at::cuda::jit::BinaryFuncVariant::LhsScalar>(
+        iter,
+        f,
+        needs_dynamic_casting,
+        scalar_val,
+        extra_args);
   }
 }
 
+// TODO: support runtime state capture similar to `jitted_gpu_kernel`.
 template <char const *name, typename return_type, typename f_inputs_type>
 void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) {
   TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
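
Putting the pieces together: the CUDA polygamma kernel itself lives in one of the changed files not shown in this excerpt, but under the ordering rule documented in Loops.cuh a caller looks roughly like the following. This is a hypothetical sketch, not the commit's exact code; `polygamma_name`, the jittable string, and the dispatch macro choice are illustrative assumptions.

// Hypothetical sketch: capture the runtime order `n` for a unary jiterator op.
constexpr char polygamma_name[] = "polygamma";

// Jittable string: the extra arg `n` trails the iterator-provided `x`.
const std::string polygamma_string = R"(
template <typename T>
T polygamma(T x, int n) {
  return calc_polygamma<T, /*is_cuda=*/true>(x, n);
}
)";

void polygamma_kernel_cuda(TensorIteratorBase& iter, int64_t n) {
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "polygamma_cuda", [&]() {
    jitted_gpu_kernel</*name=*/polygamma_name,
                      /*return_type=*/scalar_t,
                      /*f_inputs_type=*/scalar_t,
                      /*arity=*/1>(
        iter, polygamma_string,
        at::cuda::jit::BinaryFuncVariant::NoScalar,
        /*scalar_val=*/0,
        std::make_tuple(static_cast<int>(n)));  // extra args, in signature order
  });
}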
