/****************************************************************************
 * apps/mlearning/tflite-micro/operators/neon/arm_convolve_s8.c
 *
 * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or
 * its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ****************************************************************************/

/****************************************************************************
 * Included Files
 ****************************************************************************/
18
25
19
26
#include <arm_neon.h>
20
27
#include "arm_nnfunctions.h"
21
28
#include "arm_nnsupportfunctions.h"
22
29
/****************************************************************************
 * Public Functions
 ****************************************************************************/
32
- /*
33
- * Basic s8 convolution function.
34
- *
35
- * Refer header file for details. Optimal use case for the DSP/MVE implementation is when input and output channels
36
- * are multiples of 4 or atleast greater than 4.
34
+ /* Basic s8 convolution function.
37
35
*
36
+ * Refer header file for details. Optimal use case for the DSP/MVE
37
+ * implementation is when input and output channels are multiples of 4 or
38
+ * atleast greater than 4.
38
39
*/
39
- arm_cmsis_nn_status arm_convolve_s8 (const cmsis_nn_context * ctx ,
40
- const cmsis_nn_conv_params * conv_params ,
41
- const cmsis_nn_per_channel_quant_params * quant_params ,
42
- const cmsis_nn_dims * input_dims ,
43
- const int8_t * input_data ,
44
- const cmsis_nn_dims * filter_dims ,
45
- const int8_t * filter_data ,
46
- const cmsis_nn_dims * bias_dims ,
47
- const int32_t * bias_data ,
48
- const cmsis_nn_dims * output_dims ,
49
- int8_t * output_data )
40
+
41
+ arm_cmsis_nn_status
42
+ arm_convolve_s8 (const cmsis_nn_context * ctx ,
43
+ const cmsis_nn_conv_params * conv_params ,
44
+ const cmsis_nn_per_channel_quant_params * quant_params ,
45
+ const cmsis_nn_dims * input_dims ,
46
+ const int8_t * input_data ,
47
+ const cmsis_nn_dims * filter_dims ,
48
+ const int8_t * filter_data ,
49
+ const cmsis_nn_dims * bias_dims ,
50
+ const int32_t * bias_data ,
51
+ const cmsis_nn_dims * output_dims ,
52
+ int8_t * output_data )
50
53
{
51
- (void )bias_dims ;
54
+ (void )bias_dims ;
52
55
53
- if (ctx -> buf == NULL )
56
+ if (ctx -> buf == NULL )
54
57
{
55
- return ARM_CMSIS_NN_ARG_ERROR ;
58
+ return ARM_CMSIS_NN_ARG_ERROR ;
56
59
}
57
- int16_t * buffer_a = (int16_t * )ctx -> buf ;
58
-
59
- const int32_t input_batches = input_dims -> n ;
60
- const uint16_t input_x = input_dims -> w ;
61
- const uint16_t input_y = input_dims -> h ;
62
- const uint16_t input_ch = input_dims -> c ;
63
- const uint16_t kernel_x = filter_dims -> w ;
64
- const uint16_t kernel_y = filter_dims -> h ;
65
- const uint16_t output_x = output_dims -> w ;
66
- const uint16_t output_y = output_dims -> h ;
67
- const uint16_t output_ch = output_dims -> c ;
68
-
69
- const uint16_t pad_x = conv_params -> padding .w ;
70
- const uint16_t pad_y = conv_params -> padding .h ;
71
- const uint16_t stride_x = conv_params -> stride .w ;
72
- const uint16_t stride_y = conv_params -> stride .h ;
73
- const int32_t dilation_x = conv_params -> dilation .w ;
74
- const int32_t dilation_y = conv_params -> dilation .h ;
75
- const int32_t out_offset = conv_params -> output_offset ;
76
- const int32_t out_activation_min = conv_params -> activation .min ;
77
- const int32_t out_activation_max = conv_params -> activation .max ;
78
- const int32_t rhs_cols = kernel_x * kernel_y * input_ch ;
79
- const int32_t input_offset = conv_params -> input_offset ;
80
-
81
- int32_t * output_mult = quant_params -> multiplier ;
82
- int32_t * output_shift = quant_params -> shift ;
83
-
84
- int i_batch ;
85
- for (i_batch = 0 ; i_batch < input_batches ; i_batch ++ )
60
+
61
+ int16_t * buffer_a = (int16_t * )ctx -> buf ;
62
+
63
+ const int32_t input_batches = input_dims -> n ;
64
+ const uint16_t input_x = input_dims -> w ;
65
+ const uint16_t input_y = input_dims -> h ;
66
+ const uint16_t input_ch = input_dims -> c ;
67
+ const uint16_t kernel_x = filter_dims -> w ;
68
+ const uint16_t kernel_y = filter_dims -> h ;
69
+ const uint16_t output_x = output_dims -> w ;
70
+ const uint16_t output_y = output_dims -> h ;
71
+ const uint16_t output_ch = output_dims -> c ;
72
+
73
+ const uint16_t pad_x = conv_params -> padding .w ;
74
+ const uint16_t pad_y = conv_params -> padding .h ;
75
+ const uint16_t stride_x = conv_params -> stride .w ;
76
+ const uint16_t stride_y = conv_params -> stride .h ;
77
+ const int32_t dilation_x = conv_params -> dilation .w ;
78
+ const int32_t dilation_y = conv_params -> dilation .h ;
79
+ const int32_t out_offset = conv_params -> output_offset ;
80
+ const int32_t out_activation_min = conv_params -> activation .min ;
81
+ const int32_t out_activation_max = conv_params -> activation .max ;
82
+ const int32_t rhs_cols = kernel_x * kernel_y * input_ch ;
83
+ const int32_t input_offset = conv_params -> input_offset ;
84
+
85
+ int32_t * output_mult = quant_params -> multiplier ;
86
+ int32_t * output_shift = quant_params -> shift ;
87
+
88
+ int i_batch ;
89
+ for (i_batch = 0 ; i_batch < input_batches ; i_batch ++ )
86
90
{
87
- const int32_t remainder = rhs_cols % 4 ;
88
- const int32_t aligned_rhs_cols = remainder != 0 ? rhs_cols + 4 - remainder : rhs_cols ;
89
- /**
90
- * Use Im2col to speed up conv2d calculations.
91
- * Use as a ping-pong buffer for unordered elements.
92
- */
93
- int8_t * im2col_buf = (int8_t * )buffer_a + aligned_rhs_cols * 2 ;
94
- int16_t * im2col_buf_start_s16 = buffer_a ;
95
- int8_t * out = output_data ;
96
- int32_t lhs_rows = 0 ;
97
- /* This part implements the im2col function */
98
- for (int i_out_x = 0 ; i_out_x < output_x ; i_out_x ++ )
91
+ const int32_t remainder = rhs_cols % 4 ;
92
+ const int32_t aligned_rhs_cols = remainder != 0 ?
93
+ rhs_cols + 4 - remainder : rhs_cols ;
94
+
95
+ /**
96
+ * Use Im2col to speed up conv2d calculations.
97
+ * Use as a ping-pong buffer for unordered elements.
98
+ */
99
+
100
+ int8_t * im2col_buf = (int8_t * )buffer_a + aligned_rhs_cols * 2 ;
101
+ int16_t * im2col_buf_start_s16 = buffer_a ;
102
+ int8_t * out = output_data ;
103
+ int32_t lhs_rows = 0 ;
104
+
105
+ /* This part implements the im2col function */
106
+
107
+ for (int i_out_x = 0 ; i_out_x < output_x ; i_out_x ++ )
99
108
{
100
- const int32_t base_idx_x = stride_x * i_out_x - pad_x ;
101
- for (int i_out_y = 0 ; i_out_y < output_y ; i_out_y ++ )
109
+ const int32_t base_idx_x = stride_x * i_out_x - pad_x ;
110
+ for (int i_out_y = 0 ; i_out_y < output_y ; i_out_y ++ )
102
111
{
103
- const int32_t base_idx_y = stride_y * i_out_y - pad_y ;
104
- for (int32_t i_ker_x = 0 ; i_ker_x < kernel_x ; i_ker_x ++ )
112
+ const int32_t base_idx_y = stride_y * i_out_y - pad_y ;
113
+ for (int32_t i_ker_x = 0 ; i_ker_x < kernel_x ; i_ker_x ++ )
105
114
{
106
- int32_t k_x = base_idx_x + dilation_x * i_ker_x ;
107
- int32_t k_y = base_idx_y - dilation_y ;
108
- for (int32_t i_ker_y = 0 ; i_ker_y < kernel_y ; i_ker_y ++ )
115
+ int32_t k_x = base_idx_x + dilation_x * i_ker_x ;
116
+ int32_t k_y = base_idx_y - dilation_y ;
117
+ for (int32_t i_ker_y = 0 ; i_ker_y < kernel_y ; i_ker_y ++ )
109
118
{
110
- k_y += dilation_y ;
111
- arm_memcpy_s8 (im2col_buf ,
112
- input_data + (k_y * input_x + k_x ) * input_ch ,
113
- input_ch );
114
- im2col_buf += input_ch ;
119
+ k_y += dilation_y ;
120
+ arm_memcpy_s8 (im2col_buf ,
121
+ input_data + (k_y * input_x + k_x ) * input_ch ,
122
+ input_ch );
123
+ im2col_buf += input_ch ;
115
124
}
116
125
}
117
- lhs_rows ++ ;
118
- /* Extend the input data from int8 to int16, and add offset. */
119
- arm_q7_to_q15_with_offset (im2col_buf - rhs_cols ,
120
- im2col_buf_start_s16 ,
121
- rhs_cols ,
122
- (int16_t )input_offset );
123
- im2col_buf_start_s16 += aligned_rhs_cols ;
124
- if (lhs_rows & 2 )
126
+
127
+ lhs_rows ++ ;
128
+
129
+ /* Extend the input data from int8 to int16, and add offset. */
130
+
131
+ arm_q7_to_q15_with_offset (im2col_buf - rhs_cols ,
132
+ im2col_buf_start_s16 ,
133
+ rhs_cols ,
134
+ (int16_t )input_offset );
135
+ im2col_buf_start_s16 += aligned_rhs_cols ;
136
+ if (lhs_rows & 2 )
125
137
{
126
- out = arm_nn_mat_mult_kernel_s8_s16 (filter_data ,
127
- buffer_a ,
128
- output_ch ,
129
- output_shift ,
130
- output_mult ,
131
- out_offset ,
132
- out_activation_min ,
133
- out_activation_max ,
134
- rhs_cols ,
135
- aligned_rhs_cols ,
136
- bias_data ,
137
- out );
138
- /* counter reset */
139
- im2col_buf_start_s16 = buffer_a ;
140
- im2col_buf = (int8_t * )buffer_a + (aligned_rhs_cols << 1 );
141
- lhs_rows = 0 ;
138
+ out = arm_nn_mat_mult_kernel_s8_s16 (filter_data ,
139
+ buffer_a ,
140
+ output_ch ,
141
+ output_shift ,
142
+ output_mult ,
143
+ out_offset ,
144
+ out_activation_min ,
145
+ out_activation_max ,
146
+ rhs_cols ,
147
+ aligned_rhs_cols ,
148
+ bias_data ,
149
+ out );
150
+
151
+ /* counter reset */
152
+
153
+ im2col_buf_start_s16 = buffer_a ;
154
+ im2col_buf = (int8_t * )buffer_a + (aligned_rhs_cols << 1 );
155
+ lhs_rows = 0 ;
142
156
}
143
157
}
144
158
}
145
- if (lhs_rows != 0 )
159
+
160
+ if (lhs_rows != 0 )
146
161
{
147
- const int8_t * ker_a = filter_data ;
148
- int i ;
149
- for (i = 0 ; i < output_ch ; i ++ )
162
+ const int8_t * ker_a = filter_data ;
163
+ int i ;
164
+
165
+ for (i = 0 ; i < output_ch ; i ++ )
150
166
{
151
- /* Load the accumulator with bias first */
152
- uint16_t col_count = rhs_cols / 8 ;
153
- int32_t sum = 0 ;
154
- const int16_t * ip_as_col = buffer_a ;
155
- int32x4_t res_s32 = vdupq_n_s32 (0 );
156
- if (bias_data )
167
+ /* Load the accumulator with bias first */
168
+
169
+ uint16_t col_count = rhs_cols / 8 ;
170
+ int32_t sum = 0 ;
171
+ const int16_t * ip_as_col = buffer_a ;
172
+ int32x4_t res_s32 = vdupq_n_s32 (0 );
173
+ if (bias_data )
157
174
{
158
- sum = bias_data [i ];
175
+ sum = bias_data [i ];
159
176
}
160
- while (col_count )
177
+
178
+ while (col_count )
161
179
{
162
- int8x8_t filter_s8 = vld1_s8 (ker_a );
163
- int16x8_t input_s16 = vld1q_s16 (ip_as_col );
164
- int16x8_t filter_s16 = vmovl_s8 (filter_s8 );
165
- ker_a += 8 ;
166
- ip_as_col += 8 ;
167
- res_s32 = vmlal_s16 (res_s32 ,
168
- vget_low_s16 (input_s16 ),
169
- vget_low_s16 (filter_s16 ));
170
- res_s32 = vmlal_s16 (res_s32 ,
171
- vget_high_s16 (input_s16 ),
172
- vget_high_s16 (filter_s16 ));
173
- col_count -- ;
180
+ int8x8_t filter_s8 = vld1_s8 (ker_a );
181
+ int16x8_t input_s16 = vld1q_s16 (ip_as_col );
182
+ int16x8_t filter_s16 = vmovl_s8 (filter_s8 );
183
+ ker_a += 8 ;
184
+ ip_as_col += 8 ;
185
+ res_s32 = vmlal_s16 (res_s32 ,
186
+ vget_low_s16 (input_s16 ),
187
+ vget_low_s16 (filter_s16 ));
188
+ res_s32 = vmlal_s16 (res_s32 ,
189
+ vget_high_s16 (input_s16 ),
190
+ vget_high_s16 (filter_s16 ));
191
+ col_count -- ;
174
192
}
175
- sum += vgetq_lane_s32 (res_s32 , 0 );
176
- sum += vgetq_lane_s32 (res_s32 , 1 );
177
- sum += vgetq_lane_s32 (res_s32 , 2 );
178
- sum += vgetq_lane_s32 (res_s32 , 3 );
179
- col_count = rhs_cols % 8 ;
180
- while (col_count )
193
+
194
+ sum += vgetq_lane_s32 (res_s32 , 0 );
195
+ sum += vgetq_lane_s32 (res_s32 , 1 );
196
+ sum += vgetq_lane_s32 (res_s32 , 2 );
197
+ sum += vgetq_lane_s32 (res_s32 , 3 );
198
+ col_count = rhs_cols % 8 ;
199
+ while (col_count )
181
200
{
182
- int8_t ker_a1 = * ker_a ++ ;
183
- int16_t ip_b1 = * ip_as_col ++ ;
184
- sum += ker_a1 * ip_b1 ;
185
- col_count -- ;
201
+ int8_t ker_a1 = * ker_a ++ ;
202
+ int16_t ip_b1 = * ip_as_col ++ ;
203
+ sum += ker_a1 * ip_b1 ;
204
+ col_count -- ;
186
205
}
187
- sum = arm_nn_requantize (sum , output_mult [i ], output_shift [i ]);
188
- sum += out_offset ;
189
- sum = MAX (sum , out_activation_min );
190
- sum = MIN (sum , out_activation_max );
191
- * out ++ = (int8_t )sum ;
206
+
207
+ sum = arm_nn_requantize (sum ,
208
+ output_mult [i ], output_shift [i ]);
209
+ sum += out_offset ;
210
+ sum = MAX (sum , out_activation_min );
211
+ sum = MIN (sum , out_activation_max );
212
+ * out ++ = (int8_t )sum ;
192
213
}
193
214
}
194
- /* Advance to the next batch */
195
- input_data += (input_x * input_y * input_ch );
196
- output_data += (output_x * output_y * output_ch );
215
+
216
+ /* Advance to the next batch */
217
+
218
+ input_data += (input_x * input_y * input_ch );
219
+ output_data += (output_x * output_y * output_ch );
197
220
}
198
- /* Return to application */
199
- return ARM_CMSIS_NN_SUCCESS ;
221
+
222
+ /* Return to application */
223
+
224
+ return ARM_CMSIS_NN_SUCCESS ;
200
225
}