forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpatialUpSamplingBicubic.cu
157 lines (137 loc) · 4.6 KB
/
SpatialUpSamplingBicubic.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#include <THCUNN/THCUNN.h>
#include <THC/THCTensor.hpp>
#include <THCUNN/common.h>
#include <THCUNN/upsampling.h>
#include <THC/THCDeviceTensor.cuh>
#include <THC/THCDeviceTensorUtils.cuh>
#include <THC/THCDeviceUtils.cuh>
#include <TH/THHalf.h>
#include <THCUNN/THCHalfAutoNumerics.cuh>
#include <THC/THCAtomics.cuh>
// Forward bicubic upsampling kernel.
//
// Launch layout: 1-D grid. Each thread handles one output spatial location
// (output_y, output_x), decoded from its flat index. It loops over every
// (batch, channel) pair at that location. Threads with index >= num_elements
// exit early. num_elements is presumably output_height * output_width
// (chosen by the host wrapper -- confirm there).
//
// The source coordinate is scale * output coordinate:
//   - its integer part selects the 4x4 input window [in - 1, in + 2] on each axis;
//   - its fractional part is the cubic interpolation parameter.
// Taps that fall outside the input are resolved by
// upsampling_get_value_bounded.
template<typename Dtype, typename Acctype>
__global__ void bicubic_interp2d_kernel(
  const int num_elements,
  const Acctype height_scale,
  const Acctype width_scale,
  const THCDeviceTensor<Dtype, 4> in_data,
  THCDeviceTensor<Dtype, 4> out_data
) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index >= num_elements) {
    return;
  }

  const int batchsize = in_data.getSize(0);
  const int channels = in_data.getSize(1);
  const int input_height = in_data.getSize(2);
  const int input_width = in_data.getSize(3);
  const int output_height = out_data.getSize(2);
  const int output_width = out_data.getSize(3);

  const int output_x = index % output_width;
  const int output_y = index / output_width;

  // Special case: input and output are the same size, just copy.
  if (input_height == output_height && input_width == output_width) {
    for (int n = 0; n < batchsize; n++) {
      for (int c = 0; c < channels; c++) {
        // Read and write with the same [n][c][y][x] order. Swapping x/y here
        // would transpose the image and, for non-square inputs, index the
        // height dimension with a width coordinate (out of bounds).
        const Dtype val = in_data[n][c][output_y][output_x];
        out_data[n][c][output_y][output_x] = val;
      }
    }
    return;
  }

  // Map the output pixel to a source coordinate and split it into an integer
  // anchor plus a fractional offset. The int conversion truncates. That equals
  // floor only while scale * output coordinate is non-negative; the scales are
  // presumably non-negative (set by the caller -- confirm).
  Acctype real_x = width_scale * output_x;
  int in_x = real_x;
  Acctype t_x = real_x - in_x;

  Acctype real_y = height_scale * output_y;
  int in_y = real_y;
  Acctype t_y = real_y - in_y;

  for (int n = 0; n < batchsize; n++) {
    for (int c = 0; c < channels; c++) {
      // First pass: interpolate along x for each of the 4 rows in the window.
      Acctype coefficients[4];
      for (int k = 0; k < 4; k++) {
        coefficients[k] = cubic_interp1d(
          upsampling_get_value_bounded<Dtype>(
            in_data, c, n, input_width, input_height, in_x - 1, in_y - 1 + k),
          upsampling_get_value_bounded<Dtype>(
            in_data, c, n, input_width, input_height, in_x + 0, in_y - 1 + k),
          upsampling_get_value_bounded<Dtype>(
            in_data, c, n, input_width, input_height, in_x + 1, in_y - 1 + k),
          upsampling_get_value_bounded<Dtype>(
            in_data, c, n, input_width, input_height, in_x + 2, in_y - 1 + k),
          t_x
        );
      }

      // Second pass: interpolate the 4 row results along y. Accumulation is
      // in Acctype, converted back to Dtype only for the final store.
      out_data[n][c][output_y][output_x] = ScalarConvert<Acctype, Dtype>::to(cubic_interp1d(
        coefficients[0],
        coefficients[1],
        coefficients[2],
        coefficients[3],
        t_y
      ));
    }
  }
}
// Backward (adjoint) of the bicubic upsampling forward pass, mapping 2 -> 1.
// Each output gradient is scattered back into the 4x4 input neighbourhood
// that produced it. Results are ACCUMULATED into in_data (the input-gradient
// tensor), so the caller presumably zero-initialises it -- confirm in the
// host wrapper.
//
// Launch layout: 1-D grid, one thread per output spatial location. Threads
// with a flat index >= num_elements return immediately.
//
// NOTE(review): align_corners is accepted but never read here. The source
// coordinate is always scale * output coordinate.
template <typename Dtype, typename Acctype>
__global__ void bicubic_interp2d_backward_kernel(
  const int num_elements,
  const Acctype height_scale,
  const Acctype width_scale,
  const bool align_corners,
  THCDeviceTensor<Dtype, 4> in_data,
  const THCDeviceTensor<Dtype, 4> out_data
){
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid >= num_elements) {
    return;
  }

  const int nbatch = in_data.getSize(0);
  const int nchannel = in_data.getSize(1);
  const int in_h = in_data.getSize(2);
  const int in_w = in_data.getSize(3);
  const int out_h = out_data.getSize(2);
  const int out_w = out_data.getSize(3);

  // Flat thread index -> (row, column) of the output plane.
  const int ox = tid % out_w;
  const int oy = tid / out_w;

  // Identity resize: the gradient passes straight through. Every thread owns
  // a distinct (oy, ox), so this plain += does not race with other threads.
  const bool same_size = (in_h == out_h) && (in_w == out_w);
  if (same_size) {
    for (int b = 0; b < nbatch; ++b) {
      for (int ch = 0; ch < nchannel; ++ch) {
        const Dtype g = out_data[b][ch][oy][ox];
        in_data[b][ch][oy][ox] += g;
      }
    }
    return;
  }

  // Source position of this output pixel: integer anchor plus fractional part.
  Acctype src_x = width_scale * ox;
  int ix = src_x;
  Acctype frac_x = src_x - ix;
  Acctype src_y = height_scale * oy;
  int iy = src_y;
  Acctype frac_y = src_y - iy;

  // Separable cubic weights for the four taps along each axis.
  Acctype wx[4];
  Acctype wy[4];
  get_cubic_upsampling_coefficients(wx, frac_x);
  get_cubic_upsampling_coefficients(wy, frac_y);

  for (int b = 0; b < nbatch; ++b) {
    for (int ch = 0; ch < nchannel; ++ch) {
      Dtype grad = out_data[b][ch][oy][ox];
      // Scatter into rows iy-1..iy+2 and columns ix-1..ix+2, with weight
      // wy[r] * wx[s]. upsampling_increment_value_bounded handles taps that
      // fall outside the input.
      for (int r = 0; r < 4; ++r) {
        for (int s = 0; s < 4; ++s) {
          upsampling_increment_value_bounded<Dtype, Acctype>(
            in_data,
            ch,
            b,
            in_w,
            in_h,
            ix - 1 + s,
            iy - 1 + r,
            grad * wy[r] * wx[s]
          );
        }
      }
    }
  }
}
#include <THCUNN/generic/SpatialUpSamplingBicubic.cu>
#include <THC/THCGenerateFloatTypes.h>