@@ -41,18 +41,15 @@ std::vector<Tensor> foreach_tensor_list_op(
   tensor_lists.emplace_back(std::move(vec_res));
 
   using opmath_t = at::opmath_type<T>;
-  DISPATCH_MULTI_TENSOR_APPLY([&]() {
-    multi_tensor_apply<3>(
-        tensor_lists,
-        BinaryOpListAlphaFunctor<
-            T,
-            /* depth */ 3,
-            /* r_args_depth */ 2,
-            /* res_arg_index */ 2,
-            large_kernel_arg>(),
-        Op<opmath_t>(),
-        alpha.to<opmath_t>());
-  });
+  multi_tensor_apply<3>(
+      tensor_lists,
+      BinaryOpListAlphaFunctor<
+          T,
+          /* depth */ 3,
+          /* r_args_depth */ 2,
+          /* res_arg_index */ 2>(),
+      Op<opmath_t>(),
+      alpha.to<opmath_t>());
 
   return tensor_lists[2];
 }
@@ -67,18 +64,15 @@ void foreach_tensor_list_op_(
   tensor_lists.emplace_back(tensors2.vec());
 
   using opmath_t = at::opmath_type<T>;
-  DISPATCH_MULTI_TENSOR_APPLY([&]() {
-    multi_tensor_apply<2>(
-        tensor_lists,
-        BinaryOpListAlphaFunctor<
-            T,
-            /* depth */ 2,
-            /* r_args_depth */ 2,
-            /* res_arg_index */ 0,
-            large_kernel_arg>(),
-        Op<opmath_t>(),
-        alpha.to<opmath_t>());
-  });
+  multi_tensor_apply<2>(
+      tensor_lists,
+      BinaryOpListAlphaFunctor<
+          T,
+          /* depth */ 2,
+          /* r_args_depth */ 2,
+          /* res_arg_index */ 0>(),
+      Op<opmath_t>(),
+      alpha.to<opmath_t>());
 
   increment_version(tensors1);
 }
@@ -337,15 +331,13 @@ template <
     typename src_t,
     int depth,
     int r_args_depth,
-    int res_arg_index,
-    bool large_kernel_arg>
+    int res_arg_index>
 struct CopyFunctor {
-  static constexpr bool use_large_kernel_arg = large_kernel_arg;
   static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1);
   template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
-      TensorListMetadata<depth, large_kernel_arg>& tl,
+      TensorListMetadata<depth>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
     const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -428,36 +420,30 @@ void foreach_tensor_copy_list_kernel_cuda_(
   using opmath_t = at::opmath_type<scalar_t>;
   AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] {
     if constexpr (std::is_same_v<scalar_t, src_t>) {
-      DISPATCH_MULTI_TENSOR_APPLY([&]() {
-        multi_tensor_apply<2>(
-            tensor_lists,
-            UnaryOpFunctor<
-                scalar_t,
-                /* depth */ 2,
-                /* r_args_depth */ 1,
-                /* res_arg_index */ 1,
-                large_kernel_arg>(),
-            Copy<opmath_t, opmath_t>());
-      });
+      multi_tensor_apply<2>(
+          tensor_lists,
+          UnaryOpFunctor<
+              scalar_t,
+              /* depth */ 2,
+              /* r_args_depth */ 1,
+              /* res_arg_index */ 1>(),
+          Copy<opmath_t, opmath_t>());
     } else {
       // Ref:
       // https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301
       if (!self[0].is_complex() && src[0].is_complex()) {
         TORCH_WARN_ONCE(
             "Casting complex values to real discards the imaginary part");
       }
-      DISPATCH_MULTI_TENSOR_APPLY([&]() {
-        multi_tensor_apply<2>(
-            tensor_lists,
-            CopyFunctor<
-                scalar_t,
-                src_t,
-                /* depth */ 2,
-                /* r_args_depth */ 1,
-                /* res_arg_index */ 1,
-                large_kernel_arg>(),
-            Copy<scalar_t, src_t>());
-      });
+      multi_tensor_apply<2>(
+          tensor_lists,
+          CopyFunctor<
+              scalar_t,
+              src_t,
+              /* depth */ 2,
+              /* r_args_depth */ 1,
+              /* res_arg_index */ 1>(),
+          Copy<scalar_t, src_t>());
     }
   });
 });
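
Every hunk above makes the same change: the DISPATCH_MULTI_TENSOR_APPLY wrapper around each call site is dropped, the large_kernel_arg template parameter (and the use_large_kernel_arg member it fed) is removed from the functors, and multi_tensor_apply is invoked directly. As a rough, hypothetical sketch of the pattern being removed, assuming the macro's job was to fix a compile-time flag and hand it to a generic lambda (which is how the deleted call sites read), the following self-contained C++ uses placeholder names (DISPATCH_TOY, FlaggedFunctor, PlainFunctor, device_supports_large_kernel_args) rather than the real ATen implementation:

#include <cstdio>

// Hypothetical stand-in for the old functors, which carried a compile-time
// large_kernel_arg flag as their last template parameter.
template <bool large_kernel_arg>
struct FlaggedFunctor {
  void operator()() const {
    std::printf("large_kernel_arg = %d\n", large_kernel_arg ? 1 : 0);
  }
};

// Stand-in for the simplified functors after this diff: no flag parameter.
struct PlainFunctor {
  void operator()() const {
    std::printf("no large_kernel_arg parameter\n");
  }
};

// Placeholder for whatever device capability query the real macro presumably
// consulted; hard-coded here just so the sketch runs.
bool device_supports_large_kernel_args() {
  return false;
}

// Toy stand-in for DISPATCH_MULTI_TENSOR_APPLY: it fixes large_kernel_arg at
// compile time in each branch and runs the wrapped generic lambda once.
#define DISPATCH_TOY(...)                    \
  if (device_supports_large_kernel_args()) { \
    constexpr bool large_kernel_arg = true;  \
    __VA_ARGS__();                           \
  } else {                                   \
    constexpr bool large_kernel_arg = false; \
    __VA_ARGS__();                           \
  }

int main() {
  // Before: each call site was wrapped, and the flag threaded into the functor.
  DISPATCH_TOY([&] { FlaggedFunctor<large_kernel_arg>{}(); });
  // After: the wrapper is gone and the functor is instantiated directly.
  PlainFunctor{}();
  return 0;
}

In that reading, removing the wrapper leaves one instantiation per call site instead of one per flag value, and the flag disappears from every functor signature, which is exactly the shape of the +/- lines above.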