Skip to content

Commit 741d52c

Browse files
Revert "Add support for 32KB multi_tensor_apply kernel arguments (pytorch#134373)"
This reverts commit 08184aa. Reverted pytorch#134373 on behalf of https://github.com/drisspg; see pytorch#135126 for more details ([comment](pytorch#134373 (comment)))
1 parent dd7cd18 commit 741d52c

19 files changed

+470
-852
lines changed

aten/src/ATen/native/cuda/AmpKernels.cu

+15-18
Original file line number | Diff line number | Diff line change
@@ -157,24 +157,21 @@ void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads,
157157
using opmath_t = at::opmath_type<scalar_t>;
158158

159159
// multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly.
160-
DISPATCH_MULTI_TENSOR_APPLY([&]() {
161-
multi_tensor_apply<1>(tensor_lists,
162-
UnaryOpFunctor<scalar_t,
163-
/* depth */ 1,
164-
/* r_args_depth */ 1,
165-
/* res_arg_index */ 0,
166-
large_kernel_arg>(),
167-
[found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
168-
// There is a slight asymmetry here with the TensorIterator kernel above.
169-
// MTA Functors ensure val comes in as opmath_t rather than scalar_t.
170-
if (!isfinite_ensure_cuda_math(val)) {
171-
*found_inf_ptr = 1.f;
172-
}
173-
// Every thread accesses inv_scale, but it will hit in cache.
174-
const auto inv_scale_val = *inv_scale_ptr;
175-
return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
176-
});
177-
});
160+
multi_tensor_apply<1>(tensor_lists,
161+
UnaryOpFunctor<scalar_t,
162+
/* depth */ 1,
163+
/* r_args_depth */ 1,
164+
/* res_arg_index */ 0>(),
165+
[found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
166+
// There is a slight asymmetry here with the TensorIterator kernel above.
167+
// MTA Functors ensure val comes in as opmath_t rather than scalar_t.
168+
if (!isfinite_ensure_cuda_math(val)) {
169+
*found_inf_ptr = 1.f;
170+
}
171+
// Every thread accesses inv_scale, but it will hit in cache.
172+
const auto inv_scale_val = *inv_scale_ptr;
173+
return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
174+
});
178175
});
179176
}
180177

aten/src/ATen/native/cuda/ForeachBinaryOpList.cu

+37-51
Original file line number | Diff line number | Diff line change
@@ -41,18 +41,15 @@ std::vector<Tensor> foreach_tensor_list_op(
4141
tensor_lists.emplace_back(std::move(vec_res));
4242

4343
using opmath_t = at::opmath_type<T>;
44-
DISPATCH_MULTI_TENSOR_APPLY([&]() {
45-
multi_tensor_apply<3>(
46-
tensor_lists,
47-
BinaryOpListAlphaFunctor<
48-
T,
49-
/* depth */ 3,
50-
/* r_args_depth */ 2,
51-
/* res_arg_index */ 2,
52-
large_kernel_arg>(),
53-
Op<opmath_t>(),
54-
alpha.to<opmath_t>());
55-
});
44+
multi_tensor_apply<3>(
45+
tensor_lists,
46+
BinaryOpListAlphaFunctor<
47+
T,
48+
/* depth */ 3,
49+
/* r_args_depth */ 2,
50+
/* res_arg_index */ 2>(),
51+
Op<opmath_t>(),
52+
alpha.to<opmath_t>());
5653

5754
return tensor_lists[2];
5855
}
@@ -67,18 +64,15 @@ void foreach_tensor_list_op_(
6764
tensor_lists.emplace_back(tensors2.vec());
6865

6966
using opmath_t = at::opmath_type<T>;
70-
DISPATCH_MULTI_TENSOR_APPLY([&]() {
71-
multi_tensor_apply<2>(
72-
tensor_lists,
73-
BinaryOpListAlphaFunctor<
74-
T,
75-
/* depth */ 2,
76-
/* r_args_depth */ 2,
77-
/* res_arg_index */ 0,
78-
large_kernel_arg>(),
79-
Op<opmath_t>(),
80-
alpha.to<opmath_t>());
81-
});
67+
multi_tensor_apply<2>(
68+
tensor_lists,
69+
BinaryOpListAlphaFunctor<
70+
T,
71+
/* depth */ 2,
72+
/* r_args_depth */ 2,
73+
/* res_arg_index */ 0>(),
74+
Op<opmath_t>(),
75+
alpha.to<opmath_t>());
8276
increment_version(tensors1);
8377
}
8478

@@ -337,15 +331,13 @@ template <
337331
typename src_t,
338332
int depth,
339333
int r_args_depth,
340-
int res_arg_index,
341-
bool large_kernel_arg>
334+
int res_arg_index>
342335
struct CopyFunctor {
343-
static constexpr bool use_large_kernel_arg = large_kernel_arg;
344336
static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1);
345337
template <typename Op>
346338
__device__ __forceinline__ void operator()(
347339
int chunk_size,
348-
TensorListMetadata<depth, large_kernel_arg>& tl,
340+
TensorListMetadata<depth>& tl,
349341
Op op) {
350342
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
351343
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -428,36 +420,30 @@ void foreach_tensor_copy_list_kernel_cuda_(
428420
using opmath_t = at::opmath_type<scalar_t>;
429421
AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] {
430422
if constexpr (std::is_same_v<scalar_t, src_t>) {
431-
DISPATCH_MULTI_TENSOR_APPLY([&]() {
432-
multi_tensor_apply<2>(
433-
tensor_lists,
434-
UnaryOpFunctor<
435-
scalar_t,
436-
/* depth */ 2,
437-
/* r_args_depth */ 1,
438-
/* res_arg_index */ 1,
439-
large_kernel_arg>(),
440-
Copy<opmath_t, opmath_t>());
441-
});
423+
multi_tensor_apply<2>(
424+
tensor_lists,
425+
UnaryOpFunctor<
426+
scalar_t,
427+
/* depth */ 2,
428+
/* r_args_depth */ 1,
429+
/* res_arg_index */ 1>(),
430+
Copy<opmath_t, opmath_t>());
442431
} else {
443432
// Ref:
444433
// https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301
445434
if (!self[0].is_complex() && src[0].is_complex()) {
446435
TORCH_WARN_ONCE(
447436
"Casting complex values to real discards the imaginary part");
448437
}
449-
DISPATCH_MULTI_TENSOR_APPLY([&]() {
450-
multi_tensor_apply<2>(
451-
tensor_lists,
452-
CopyFunctor<
453-
scalar_t,
454-
src_t,
455-
/* depth */ 2,
456-
/* r_args_depth */ 1,
457-
/* res_arg_index */ 1,
458-
large_kernel_arg>(),
459-
Copy<scalar_t, src_t>());
460-
});
438+
multi_tensor_apply<2>(
439+
tensor_lists,
440+
CopyFunctor<
441+
scalar_t,
442+
src_t,
443+
/* depth */ 2,
444+
/* r_args_depth */ 1,
445+
/* res_arg_index */ 1>(),
446+
Copy<scalar_t, src_t>());
461447
}
462448
});
463449
});

aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu

+18-24
Original file line number | Diff line number | Diff line change
@@ -36,18 +36,15 @@ std::vector<Tensor> foreach_binary_op(
3636
tensor_lists.emplace_back(std::move(vec_res));
3737

3838
using opmath_t = at::opmath_type<T>;
39-
DISPATCH_MULTI_TENSOR_APPLY([&]() {
40-
multi_tensor_apply<2>(
41-
tensor_lists,
42-
BinaryOpScalarFunctor<
43-
T,
44-
/* depth */ 2,
45-
/* r_args_depth */ 1,
46-
/* res_arg_index */ 1,
47-
large_kernel_arg>(),
48-
Op<opmath_t>(),
49-
scalar.to<opmath_t>());
50-
});
39+
multi_tensor_apply<2>(
40+
tensor_lists,
41+
BinaryOpScalarFunctor<
42+
T,
43+
/* depth */ 2,
44+
/* r_args_depth */ 1,
45+
/* res_arg_index */ 1>(),
46+
Op<opmath_t>(),
47+
scalar.to<opmath_t>());
5148
return tensor_lists[1];
5249
}
5350

@@ -57,18 +54,15 @@ void foreach_binary_op_(TensorList tensors, const Scalar& scalar) {
5754
tensor_lists.emplace_back(tensors.vec());
5855

5956
using opmath_t = at::opmath_type<T>;
60-
DISPATCH_MULTI_TENSOR_APPLY([&]() {
61-
multi_tensor_apply<1>(
62-
tensor_lists,
63-
BinaryOpScalarFunctor<
64-
T,
65-
/* depth */ 1,
66-
/* r_args_depth */ 1,
67-
/* res_arg_index */ 0,
68-
large_kernel_arg>(),
69-
Op<opmath_t>(),
70-
scalar.to<opmath_t>());
71-
});
57+
multi_tensor_apply<1>(
58+
tensor_lists,
59+
BinaryOpScalarFunctor<
60+
T,
61+
/* depth */ 1,
62+
/* r_args_depth */ 1,
63+
/* res_arg_index */ 0>(),
64+
Op<opmath_t>(),
65+
scalar.to<opmath_t>());
7266
increment_version(tensors);
7367
}
7468

aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu

+19-25
Original file line number | Diff line number | Diff line change
@@ -36,19 +36,16 @@ std::vector<Tensor> foreach_binary_op(
3636
tensor_lists.emplace_back(vec_res);
3737

3838
using opmath_t = at::opmath_type<T>;
39-
DISPATCH_MULTI_TENSOR_APPLY([&]() {
40-
multi_tensor_apply<2, opmath_t>(
41-
tensor_lists,
42-
scalars,
43-
BinaryOpScalarListFunctor<
44-
T,
45-
/* depth */ 2,
46-
/* r_args_depth */ 1,
47-
/* res_arg_index */ 1,
48-
large_kernel_arg>(),
49-
50-
Op<opmath_t>());
51-
});
39+
multi_tensor_apply<2, opmath_t>(
40+
tensor_lists,
41+
scalars,
42+
BinaryOpScalarListFunctor<
43+
T,
44+
/* depth */ 2,
45+
/* r_args_depth */ 1,
46+
/* res_arg_index */ 1>(),
47+
48+
Op<opmath_t>());
5249
return tensor_lists[1];
5350
}
5451

@@ -58,18 +55,15 @@ void foreach_binary_op_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
5855
tensor_lists.emplace_back(tensors.vec());
5956

6057
using opmath_t = at::opmath_type<T>;
61-
DISPATCH_MULTI_TENSOR_APPLY([&]() {
62-
multi_tensor_apply<1, opmath_t>(
63-
tensor_lists,
64-
scalars,
65-
BinaryOpScalarListFunctor<
66-
T,
67-
/* depth */ 1,
68-
/* r_args_depth */ 1,
69-
/* res_arg_index */ 0,
70-
large_kernel_arg>(),
71-
Op<opmath_t>());
72-
});
58+
multi_tensor_apply<1, opmath_t>(
59+
tensor_lists,
60+
scalars,
61+
BinaryOpScalarListFunctor<
62+
T,
63+
/* depth */ 1,
64+
/* r_args_depth */ 1,
65+
/* res_arg_index */ 0>(),
66+
Op<opmath_t>());
7367
increment_version(tensors);
7468
}
7569

aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu

+20-26
Original file line number | Diff line number | Diff line change
@@ -46,19 +46,16 @@ std::vector<Tensor> foreach_binary_op(
4646
tensor_lists.emplace_back(std::move(vec_res));
4747

4848
using opmath_t = at::opmath_type<T>;
49-
DISPATCH_MULTI_TENSOR_APPLY([&]() {
50-
multi_tensor_apply<2>(
51-
tensor_lists,
52-
BinaryOpScalarTensorFunctor<
53-
T,
54-
/* depth */ 2,
55-
/* r_args_depth */ 1,
56-
/* res_arg_index */ 1,
57-
large_kernel_arg>(),
58-
Op<opmath_t>(),
59-
scalar.data_ptr<T>(),
60-
alpha.to<opmath_t>());
61-
});
49+
multi_tensor_apply<2>(
50+
tensor_lists,
51+
BinaryOpScalarTensorFunctor<
52+
T,
53+
/* depth */ 2,
54+
/* r_args_depth */ 1,
55+
/* res_arg_index */ 1>(),
56+
Op<opmath_t>(),
57+
scalar.data_ptr<T>(),
58+
alpha.to<opmath_t>());
6259
return tensor_lists[1];
6360
}
6461

@@ -84,19 +81,16 @@ void foreach_binary_op_(
8481
tensor_lists.emplace_back(tensors.vec());
8582

8683
using opmath_t = at::opmath_type<T>;
87-
DISPATCH_MULTI_TENSOR_APPLY([&]() {
88-
multi_tensor_apply<1>(
89-
tensor_lists,
90-
BinaryOpScalarTensorFunctor<
91-
T,
92-
/* depth */ 1,
93-
/* r_args_depth */ 1,
94-
/* res_arg_index */ 0,
95-
large_kernel_arg>(),
96-
Op<opmath_t>(),
97-
scalar.data_ptr<T>(),
98-
alpha.to<opmath_t>());
99-
});
84+
multi_tensor_apply<1>(
85+
tensor_lists,
86+
BinaryOpScalarTensorFunctor<
87+
T,
88+
/* depth */ 1,
89+
/* r_args_depth */ 1,
90+
/* res_arg_index */ 0>(),
91+
Op<opmath_t>(),
92+
scalar.data_ptr<T>(),
93+
alpha.to<opmath_t>());
10094
increment_version(tensors);
10195
}
10296

0 commit comments

Comments
 (0)