forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ConvShared.cpp
599 lines (531 loc) · 23.8 KB
/
ConvShared.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED
#if AT_CUDNN_ENABLED()
#include <ATen/native/cudnn/ConvShared.h>
// NOTE [cuDNN API version]
//
// ConvPlaceholders.cpp contains placeholder implementation of cudnn
// convolution when cudnn is not enabled. These operators only raise
// errors and do no real computation. This file also contains deprecated
// operators, which are implemented in terms of the current operators.
//
// cuDNN v7 and v8 have different API. ConvShared.{cpp, h} contains
// code shared by v7 and v8. Conv_v7.cpp contains implementation of
// convolution using cuDNN v7 API. Conv_v8.cpp contains implementation
// with v8 API.
//
// NOTE [ Convolution design ]
//
// cuDNN convolutions do not handle bias. Bias is handled outside.
//
// The general strategy:
//
// - cudnn_convolution (Tensor)
// Entry points for clients
//
// - cudnn_convolution_forward (TensorArg)
// Entry point, which may be reused between regular
// convolution and transposed convolution.
//
// - raw_cudnn_convolution_forward_out (Tensor)
// Function that has different implementation on Conv_v7.cpp
// and Conv_v8.cpp
//
// The raw API directly invokes cuDNN and is implemented differently
// for cuDNN v7 and cuDNN v8.
//
// There are a few reasons this should never be directly exposed
// via ATen:
//
// - It takes output as a parameter (this should be computed!)
// - It doesn't do input checking
// - It doesn't resize output (it is assumed to be correctly sized)
//
// Where does argument checking happen? Here's the division of
// responsibility:
// - Things that happen in at::Tensor
// - TensorArg allocation
// - Things that happen in TensorArg
// - Check arguments (type, GPU, shape)
namespace at { namespace native {
// ---------------------------------------------------------------------
//
// ConvolutionParams
//
// ---------------------------------------------------------------------
// Writes a human-readable, multi-line description of `params` to `out`:
// the cuDNN data type, the padding/stride/dilation arrays, the group count,
// and the deterministic / allow_tf32 flags. Returns `out` for chaining.
std::ostream& operator<<(std::ostream & out, const ConvolutionParams& params) {
  // Booleans are rendered as the words "true" / "false".
  const auto yes_no = [](bool flag) -> const char* { return flag ? "true" : "false"; };

  out << "ConvolutionParams \n";
  out << " data_type = " << cudnnTypeToString(params.dataType) << "\n";
  out << " padding = " << ArrayRef<int>{params.padding} << "\n";
  out << " stride = " << ArrayRef<int>{params.stride} << "\n";
  out << " dilation = " << ArrayRef<int>{params.dilation} << "\n";
  out << " groups = " << params.groups << "\n";
  out << " deterministic = " << yes_no(params.deterministic) << "\n";
  out << " allow_tf32 = " << yes_no(params.allow_tf32) << "\n";
  return out;
}
// NB: This can't be a constructor, because then ConvolutionParams
// would not be a POD anymore.
// TODO: Use TensorGeometry here instead of the entire Tensor, which we
// don't actually need. (OTOH: We can always pass in
// grad_input/grad_output, so this is not very pressing)
//
// Fills `*params` with a description of one convolution. The struct is
// zeroed first, so unused array slots and any padding bytes are 0. It
// records:
// - the current CUDA device and the cuDNN data type derived from `input`;
// - the sizes of `input` and `weight`, and `input`'s suggested memory format;
// - per-spatial-dim padding/stride/dilation;
// - `groups` and the determinism/TF32 flags.
//
// Preconditions (checked with internal asserts):
// - weight.dim() == input.dim();
// - stride.size() == dilation.size() == padding.size().
// NOTE(review): this also assumes input.dim() and padding.size() fit the
// fixed-size arrays in ConvolutionParams (see ConvShared.h); that is not
// checked here.
void setConvolutionParams(
    ConvolutionParams* params,
    const at::Tensor& input, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool deterministic, bool allow_tf32) {
  // The loops below index weight.sizes() with input's rank, and
  // stride/dilation with padding's length. Enforce those preconditions
  // instead of silently reading out of bounds.
  TORCH_INTERNAL_ASSERT(
      weight.dim() == input.dim(),
      "setConvolutionParams: weight.dim() (", weight.dim(),
      ") must equal input.dim() (", input.dim(), ")");
  TORCH_INTERNAL_ASSERT(
      stride.size() == padding.size() && dilation.size() == padding.size(),
      "setConvolutionParams: padding, stride and dilation must have the same length");

  cudnnDataType_t dataType = getCudnnDataType(input);
  memset(params, 0, sizeof(ConvolutionParams));
  params->device_id = at::cuda::current_device();
  params->dataType = dataType;
  params->input_dim = input.dim();
  params->memory_format = input.suggest_memory_format();
  for (int i = 0; i != params->input_dim; ++i) {
    params->input_size[i] = static_cast<int>(input.sizes()[i]);
    params->weight_size[i] = static_cast<int>(weight.sizes()[i]);
  }
  for (size_t i = 0; i != padding.size(); ++i) {
    params->padding[i] = padding[i];
    params->stride[i] = stride[i];
    params->dilation[i] = dilation[i];
  }
  // In principle, we shouldn't parametrize by groups for legacy
  // CuDNN, but it doesn't seem worth the effort to actually do this.
  params->groups = groups;
  params->deterministic = deterministic;
  params->allow_tf32 = allow_tf32;
}
// Builds a Python snippet that re-creates the convolution described by
// `params` so a user can try to reproduce a cuDNN failure.
//
// The snippet covers:
// - the dtype, input shape, and in/out channels;
// - kernel size, padding, stride, dilation, and groups;
// - channels-last memory format, when applicable;
// - the torch.backends flags that influence algorithm selection.
//
// Returns the snippet preceded by a short explanatory sentence.
std::string repro_from_args(const ConvolutionParams& params) {
  auto pybool = [](bool b) -> const char* { return b ? "True" : "False"; };

  // Python-side dtype name; anything other than float/double/half is
  // reported as "unsupported".
  std::string partial_dtype;
  switch (params.dataType) {
    case CUDNN_DATA_FLOAT: partial_dtype = "float"; break;
    case CUDNN_DATA_DOUBLE: partial_dtype = "double"; break;
    case CUDNN_DATA_HALF: partial_dtype = "half"; break;
    default: partial_dtype = "unsupported";
  }
  const std::string full_dtype = "torch." + partial_dtype;

  const int out_channels = params.weight_size[0];
  // weight_size[1] holds the per-group input channel count.
  const int in_channels = params.weight_size[1] * params.groups;
  const size_t dim = params.input_dim;

  const std::string channels_last_xd = dim == 4 ? "channels_last" : "channels_last_3d";
  const bool is_channels_last =
      params.memory_format == at::MemoryFormat::ChannelsLast ||
      params.memory_format == at::MemoryFormat::ChannelsLast3d;
  const std::string to_channels_last =
      is_channels_last ? ".to(memory_format=torch." + channels_last_xd + ")" : "";

  std::ostringstream ss;
  ss << "You can try to repro this exception using the following code snippet. ";
  ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n";
  ss << "import torch\n";
  ss << "torch.backends.cuda.matmul.allow_tf32 = " << pybool(at::globalContext().allowTF32CuBLAS()) << "\n";
  ss << "torch.backends.cudnn.benchmark = " << pybool(at::globalContext().benchmarkCuDNN()) << "\n";
  ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic) << "\n";
  ss << "torch.backends.cudnn.allow_tf32 = " << pybool(params.allow_tf32) << "\n";
  ss << "data = torch.randn(" << ArrayRef<int>(params.input_size, dim) << ", dtype=" << full_dtype << ", ";
  ss << "device='cuda', requires_grad=True)" << to_channels_last << "\n";
  ss << "net = torch.nn.Conv" << dim-2 << "d(" << in_channels << ", " << out_channels << ", ";
  // Spatial kernel extents start after the (out_channels, in_channels) dims.
  ss << "kernel_size=" << ArrayRef<int>(&params.weight_size[2], dim - 2) << ", ";
  ss << "padding=" << ArrayRef<int>(params.padding, dim-2) << ", ";
  ss << "stride=" << ArrayRef<int>(params.stride, dim-2) << ", ";
  ss << "dilation=" << ArrayRef<int>(params.dilation, dim-2) << ", ";
  ss << "groups=" << params.groups << ")\n";
  ss << "net = net.cuda()." << partial_dtype << "()" << to_channels_last << "\n";
  ss << "out = net(data)\n";
  ss << "out.backward(torch.randn_like(out))\n";
  ss << "torch.cuda.synchronize()\n\n";
  return ss.str();
}
// ---------------------------------------------------------------------
//
// Checking
//
// ---------------------------------------------------------------------
// Used on pad, stride and dilation
//
// Checks that `args` holds exactly `expected_size` values and that none of
// them is negative. Raises an error naming `arg_name` and the checking
// context `c` otherwise. Zero values are accepted; only x < 0 is rejected.
static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
{
  TORCH_CHECK(args.size() <= expected_size,
    "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
    expected_size, " (while checking arguments for ", c, ")");
  TORCH_CHECK(args.size() >= expected_size,
    "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
    expected_size, " (while checking arguments for ", c, ")");

  // Compare and print as int64_t. Narrowing to int would wrap values outside
  // int's range: e.g. -(1LL << 32) would become 0 and pass the check, while
  // 1LL << 31 would become negative and be rejected.
  const bool has_negative =
      std::any_of(args.begin(), args.end(), [](int64_t x) { return x < 0; });
  if (has_negative) {
    std::stringstream ss;
    ss << arg_name << " should be greater than zero but got (";
    // args is non-empty here (it contains a negative value), so end() - 1
    // and back() are valid.
    std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int64_t>(ss, ", "));
    ss << args.back() << ")" << " (while checking arguments for " << c << ")";
    AT_ERROR(ss.str());
  }
}
// NOTE [ Convolution checks ]
//
// NB: For many call sites, it is not strictly necessary to check all of
// these relationships (for example, for forward convolution, we compute
// the size of output ourselves, so we don't actually need to check
// output). However, writing a single function that does everything
// means we get to reuse it for both forwards and all backwards
// variants, even when the set of "real" inputs varies. The magic of
// relational computing!
//
// (There is one downside, which is that it is slightly harder to write
// error messages which are able to distinguish between real inputs
// (which the user can change) and computed inputs (which the user can
// only indirectly affect). It would be an interesting exercise to
// come up with a general framework to handle such situations.)
// Validates the geometric relationship between input, weight and output of a
// convolution, raising an error (tagged with `c`) on the first violation:
// - padding has one entry per spatial dim (input rank minus 2), and stride and
//   dilation have the same length as padding, all non-negative;
// - input rank is in [3, 6);
// - input channels equal weight->size(1) * groups;
// - weight and output have the same rank as input.
static void convolution_shape_check(
    CheckedFrom c,
    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
  // Spatial hyper-parameters.
  check_args(c, padding, input->dim() - 2, "padding");
  const size_t num_spatial_dims = padding.size();
  check_args(c, stride, num_spatial_dims, "stride");
  check_args(c, dilation, num_spatial_dims, "dilation");

  // Input rank and channel count.
  checkDimRange(c, input, 3, 6 /* exclusive */);
  const int64_t expected_input_channels = weight->size(1) * groups;
  checkSize(c, input, input_channels_dim, expected_input_channels);

  // Weight and output must agree with input on rank.
  // TODO: check that output->size() matches output_sizes
  // TODO: check that weight matches output->sizes()
  checkSameDim(c, input, weight);
  checkSameDim(c, input, output);
}
// ---------------------------------------------------------------------
//
// Convolution forward / Transposed convolution backward
//
// ---------------------------------------------------------------------
// Shared forward path for regular convolution and for the input gradient of
// transposed convolution.
//
// Steps:
// 1. Validate the operands and allocate a CUDA output of conv_output_size.
//    If that output is empty, return it immediately.
// 2. Shape-check, bring input/weight into the chosen memory format, and
//    dispatch to raw_cudnn_convolution_forward_out.
//
// Returns the output tensor.
Tensor cudnn_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
{
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  // Channels-last (the 3d flavour for 5-d weights) when cuDNN should use it
  // for this input/weight pair; otherwise plain contiguous.
  const at::MemoryFormat memory_format =
      cudnn_conv_use_channels_last(*input, *weight)
          ? (weight->ndimension() == 5 ? at::MemoryFormat::ChannelsLast3d
                                       : at::MemoryFormat::ChannelsLast)
          : at::MemoryFormat::Contiguous;

  const auto output_sizes =
      conv_output_size(input->sizes(), weight->sizes(), padding, stride, dilation);
  Tensor output_t = at::native::empty_cuda(
      output_sizes,
      /*dtype=*/input->scalar_type(),
      /*layout=*/c10::nullopt,
      /*device=*/kCUDA,
      /*pin_memory=*/c10::nullopt,
      /*memory_format=*/memory_format);

  // Nothing for cuDNN to compute.
  if (output_t.numel() == 0) {
    return output_t;
  }

  // Named "result" to avoid the ambiguity of "output" when this is being
  // used as a backward.
  TensorArg output{ output_t, "result", 0 };
  convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

  // See #4500
  // The resize_ with memory_format makes sure that NC11 strides follow formula.
  Tensor weight_contig = weight->contiguous(memory_format);
  weight_contig.resize_(weight_contig.sizes(), memory_format);
  Tensor input_contig = input->contiguous(memory_format);
  input_contig.resize_(input_contig.sizes(), memory_format);

  raw_cudnn_convolution_forward_out(
      output_t, input_contig, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
  return output_t;
}
// Public entry point for a regular cuDNN convolution (no bias). Wraps the
// tensors as named arguments and delegates to cudnn_convolution_forward,
// tagging checks with "cudnn_convolution".
Tensor cudnn_convolution(
    const Tensor& input_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)
{
  const TensorArg input{ input_t, "input", 1 };
  const TensorArg weight{ weight_t, "weight", 2 };
  return cudnn_convolution_forward(
      "cudnn_convolution", input, weight,
      padding, stride, dilation, groups,
      benchmark, deterministic, allow_tf32);
}
// NB: output_padding not needed here, as there is no ambiguity to
// resolve
//
// Gradient w.r.t. the input of a transposed convolution, computed as a
// regular convolution forward of grad_output with weight.
Tensor cudnn_convolution_transpose_backward_input(
    const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)
{
  const CheckedFrom checked_from = "cudnn_convolution_transpose_backward_input";
  const TensorArg grad_output{ grad_output_t, "grad_output", 1 };
  const TensorArg weight{ weight_t, "weight", 2 };
  return cudnn_convolution_forward(
      checked_from, grad_output, weight,
      padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}
// Backward of a transposed convolution. Returns (grad_input, grad_weight).
// Each gradient is computed only when its output_mask flag is set; otherwise
// it is left as an undefined Tensor. grad_output is first made contiguous in
// input's suggested memory format.
std::tuple<at::Tensor,at::Tensor> cudnn_convolution_transpose_backward(
    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask) {
  const Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());

  Tensor grad_input = output_mask[0]
      ? at::cudnn_convolution_transpose_backward_input(
            grad_output, weight, padding, stride, dilation, groups,
            benchmark, deterministic, allow_tf32)
      : Tensor();
  Tensor grad_weight = output_mask[1]
      ? at::cudnn_convolution_transpose_backward_weight(
            weight.sizes(), grad_output, input, padding, stride, dilation, groups,
            benchmark, deterministic, allow_tf32)
      : Tensor();

  return std::make_tuple(std::move(grad_input), std::move(grad_weight));
}
// ---------------------------------------------------------------------
//
// Convolution backward / Transposed convolution forward
//
// ---------------------------------------------------------------------
// NOTE [ Backward vs transpose convolutions ]
//
// Backward and transpose are algorithmically equivalent, but they
// compute their geometry differently. In a backward pass, you know what
// the original size of the input tensor was, so you can cache that
// geometry and fill it directly. In transposed convolution, it is
// more conventional to not explicitly specify the output (previously
// input) size, and compute it. This, however, leaves a degree of
// freedom; this degree of freedom is resolved using the
// output_padding parameter. Both of these interfaces are equivalent,
// but they are differently convenient depending on the use case.
// Shared backward-input path for regular convolution and for the forward of
// transposed convolution.
//
// Steps:
// 1. Validate the operands and allocate a CUDA tensor of `input_size`.
// 2. Shape-check, bring grad_output/weight into the chosen memory format,
//    and dispatch to raw_cudnn_convolution_backward_input_out.
//
// Returns the filled gradient.
Tensor cudnn_convolution_backward_input(
    CheckedFrom c,
    IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
{
  checkAllSameType(c, {grad_output, weight});
  checkAllSameGPU(c, {grad_output, weight});

  // Channels-last (the 3d flavour for 5-d weights) when cuDNN should use it
  // for this grad_output/weight pair; otherwise plain contiguous.
  const at::MemoryFormat memory_format =
      cudnn_conv_use_channels_last(*grad_output, *weight)
          ? (weight->ndimension() == 5 ? at::MemoryFormat::ChannelsLast3d
                                       : at::MemoryFormat::ChannelsLast)
          : at::MemoryFormat::Contiguous;

  Tensor grad_input_t = at::native::empty_cuda(
      input_size,
      /*dtype=*/grad_output->scalar_type(),
      /*layout=*/c10::nullopt,
      /*device=*/kCUDA,
      /*pin_memory=*/c10::nullopt,
      /*memory_format=*/memory_format);

  // Named "result" rather than "grad_input" because this is also used as
  // transposed convolution.
  TensorArg grad_input{ grad_input_t, "result", 0 };
  convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);

  // See #4500
  // The resize_ with memory_format makes sure that NC11 strides follow formula.
  Tensor weight_contig = weight->contiguous(memory_format);
  weight_contig.resize_(weight_contig.sizes(), memory_format);
  Tensor grad_output_contig = grad_output->contiguous(memory_format);
  grad_output_contig.resize_(grad_output_contig.sizes(), memory_format);

  raw_cudnn_convolution_backward_input_out(
      grad_input_t, grad_output_contig, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
  return grad_input_t;
}
// Forward of a transposed convolution, implemented as the backward-input pass
// of a regular convolution. The result size comes from conv_input_size, where
// output_padding resolves the size ambiguity.
Tensor cudnn_convolution_transpose_forward(
    CheckedFrom c,
    const TensorArg& grad_output, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
{
  const auto result_size = conv_input_size(
      grad_output->sizes(), weight->sizes(),
      padding, output_padding, stride, dilation, groups);
  return cudnn_convolution_backward_input(
      c, result_size, grad_output, weight,
      padding, stride, dilation, groups,
      benchmark, deterministic, allow_tf32);
}
// Public entry point for the input gradient of a regular convolution. Wraps
// the tensors as named arguments and delegates to the CheckedFrom overload,
// tagging checks with "cudnn_convolution_backward_input".
Tensor cudnn_convolution_backward_input(
    IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
{
  const CheckedFrom checked_from = "cudnn_convolution_backward_input";
  const TensorArg grad_output{ grad_output_t, "grad_output", 1 };
  const TensorArg weight{ weight_t, "weight", 2 };
  return cudnn_convolution_backward_input(
      checked_from, input_size, grad_output, weight,
      padding, stride, dilation, groups,
      benchmark, deterministic, allow_tf32);
}
// Backward of a regular convolution. Returns (grad_input, grad_weight).
// Each gradient is produced only when its output_mask flag is set; otherwise
// it stays an undefined Tensor.
//
// For an input with zero elements, cuDNN is skipped:
// - grad_input is an empty_like(input);
// - grad_weight is zeros_like(weight).
std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask) {
  const bool need_grad_input = output_mask[0];
  const bool need_grad_weight = output_mask[1];

  const Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
  Tensor grad_input;
  Tensor grad_weight;

  if (input.numel() == 0) {
    if (need_grad_input) {
      grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
    if (need_grad_weight) {
      grad_weight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
    return std::make_tuple(std::move(grad_input), std::move(grad_weight));
  }

  if (need_grad_input) {
    grad_input = at::cudnn_convolution_backward_input(
        input.sizes(), grad_output, weight, padding, stride, dilation, groups,
        benchmark, deterministic, allow_tf32);
  }
  if (need_grad_weight) {
    grad_weight = at::cudnn_convolution_backward_weight(
        weight.sizes(), grad_output, input, padding, stride, dilation, groups,
        benchmark, deterministic, allow_tf32);
  }
  return std::make_tuple(std::move(grad_input), std::move(grad_weight));
}
// Public entry point for a transposed cuDNN convolution (no bias). Wraps the
// tensors as named arguments and delegates to
// cudnn_convolution_transpose_forward, tagging checks with
// "cudnn_convolution_transpose".
Tensor cudnn_convolution_transpose(
    const Tensor& input_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)
{
  const TensorArg input{ input_t, "input", 1 };
  const TensorArg weight{ weight_t, "weight", 2 };
  return cudnn_convolution_transpose_forward(
      "cudnn_convolution_transpose", input, weight,
      padding, output_padding, stride, dilation, groups,
      benchmark, deterministic, allow_tf32);
}
// ---------------------------------------------------------------------
//
// Convolution backward (weight)
//
// ---------------------------------------------------------------------
// Shared weight-gradient path for regular and transposed convolution.
//
// Steps:
// 1. Bring grad_output and input into a common memory format and validate
//    them.
// 2. Allocate a gradient of `weight_size` in that format and shape-check it.
// 3. Dispatch to raw_cudnn_convolution_backward_weight_out.
//
// Returns the filled weight gradient.
Tensor cudnn_convolution_backward_weight(
    CheckedFrom c,
    IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
{
  // Channels-last (the 3d flavour for 5-d inputs) when cuDNN should use it
  // for this input/grad_output pair; otherwise plain contiguous.
  const at::MemoryFormat layout =
      cudnn_conv_use_channels_last(input_t, grad_output_t)
          ? (input_t.ndimension() == 5 ? at::MemoryFormat::ChannelsLast3d
                                       : at::MemoryFormat::ChannelsLast)
          : at::MemoryFormat::Contiguous;

  // The resize_ with the layout makes sure that NC11 strides follow formula.
  Tensor grad_output_contig_t = grad_output_t.contiguous(layout);
  grad_output_contig_t.resize_(grad_output_contig_t.sizes(), layout);
  TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 };

  Tensor input_contig_t = input_t.contiguous(layout);
  input_contig_t.resize_(input_contig_t.sizes(), layout);
  TensorArg input{ input_contig_t, "input", 2 };

  checkAllSameType(c, {grad_output_contig, input});
  checkAllSameGPU(c, {grad_output_contig, input});

  Tensor grad_weight_t = at::empty(weight_size, grad_output_contig_t.options(), layout);

  // Named "result" for uniformity with the other paths, although
  // "grad_weight" would be unambiguous too.
  TensorArg grad_weight{ grad_weight_t, "result", 0 };
  convolution_shape_check(c, input, grad_weight, grad_output_contig, padding, stride, dilation, groups);

  raw_cudnn_convolution_backward_weight_out(
      grad_weight_t, grad_output_contig_t, input_contig_t,
      padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
  return grad_weight_t;
}
// Public entry point for the weight gradient of a regular convolution.
// Delegates to the CheckedFrom overload, tagging checks with
// "cudnn_convolution_backward_weight".
Tensor cudnn_convolution_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
{
  const CheckedFrom checked_from = "cudnn_convolution_backward_weight";
  return cudnn_convolution_backward_weight(
      checked_from, weight_size, grad_output_t, input_t,
      padding, stride, dilation, groups,
      benchmark, deterministic, allow_tf32);
}
// Weight gradient of a transposed convolution. It is computed by calling the
// regular backward-weight routine with input_t and grad_output_t swapped.
Tensor cudnn_convolution_transpose_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
{
  // Tag checks with this entry point's own name, as
  // cudnn_convolution_transpose_backward_input does. Otherwise shape/type
  // errors would blame the non-transposed op the user never called.
  return cudnn_convolution_backward_weight(
      "cudnn_convolution_transpose_backward_weight",
      weight_size, input_t, grad_output_t,
      padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}
// Allocates the output of a fused cuDNN convolution: a contiguous CUDA tensor
// with input_t's dtype and sizes given by conv_output_size(input, weight, ...).
static Tensor cudnn_fused_conv_empty_output(
    const Tensor& input_t, const Tensor& weight_t,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation) {
  return at::native::empty_cuda(
      conv_output_size(
          input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
      /*dtype=*/input_t.scalar_type(),
      /*layout=*/c10::nullopt,
      /*device=*/kCUDA,
      /*pin_memory=*/c10::nullopt,
      /*memory_format=*/at::MemoryFormat::Contiguous);
}

// Returns bias_t when present. Otherwise returns a 1-d zero tensor with
// output_t.size(1) entries, using output_t's dtype/layout/device/pinning.
static Tensor cudnn_fused_conv_bias_or_zeros(
    const c10::optional<Tensor>& bias_t, const Tensor& output_t) {
  if (bias_t.has_value()) {
    return bias_t.value();
  }
  return at::native::zeros(
      {output_t.size(1)},
      optTypeMetaToScalarType(output_t.options().dtype_opt()),
      output_t.options().layout_opt(),
      output_t.options().device_opt(),
      output_t.options().pinned_memory_opt());
}

// Fused convolution + bias + ReLU via raw_cudnn_convolution_add_relu_out,
// with the "add" term disabled (alpha = 0). Uses a missing bias as zeros.
// Returns early if the output has no elements.
Tensor cudnn_convolution_relu(
    const Tensor& input_t,
    const Tensor& weight_t,
    const c10::optional<Tensor>& bias_t,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {
  // FuseFrozenConvAddRelu performs some tensor shape checking
  Tensor output_t =
      cudnn_fused_conv_empty_output(input_t, weight_t, stride, padding, dilation);
  if (output_t.numel() == 0) {
    return output_t;
  }
  // NOTE(review): output_t is uninitialized here. Passing it as z with
  // alpha = 0 relies on the fused kernel not letting z's contents affect
  // the result when alpha is 0 — confirm against the cuDNN API contract.
  raw_cudnn_convolution_add_relu_out(
      output_t,
      input_t,
      weight_t,
      output_t, // use output_t as z to satisfy CUDNN API
      0, // alpha
      cudnn_fused_conv_bias_or_zeros(bias_t, output_t),
      stride,
      padding,
      dilation,
      groups,
      false, // benchmark
      false, // deterministic
      input_t.dim() == 4 // enable allow_tf32 for conv2d
  );
  return output_t;
}

// Fused convolution + add + bias + ReLU via raw_cudnn_convolution_add_relu_out
// (presumably relu(conv(input, weight) + alpha * z + bias) — confirm against
// the raw implementation).
// - `alpha` defaults to 1 when absent.
// - A missing bias is treated as zeros.
// - Returns early if the output has no elements.
Tensor cudnn_convolution_add_relu(
    const Tensor& input_t,
    const Tensor& weight_t,
    const Tensor& z_t,
    const c10::optional<Scalar>& alpha,
    const c10::optional<Tensor>& bias_t,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {
  // FuseFrozenConvAddRelu performs some tensor shape checking
  Tensor output_t =
      cudnn_fused_conv_empty_output(input_t, weight_t, stride, padding, dilation);
  if (output_t.numel() == 0) {
    return output_t;
  }
  raw_cudnn_convolution_add_relu_out(
      output_t,
      input_t,
      weight_t,
      z_t,
      alpha.has_value() ? alpha.value().to<float>() : 1.0,
      cudnn_fused_conv_bias_or_zeros(bias_t, output_t),
      stride,
      padding,
      dilation,
      groups,
      false, // benchmark
      false, // deterministic
      input_t.dim() == 4 // enable allow_tf32 for conv2d
  );
  return output_t;
}
}}
#endif // AT_CUDNN_ENABLED