Skip to content

Commit 2593728

Browse files
jihandong
authored and xiaoxiang781216 committed
ml: follow nxstyle
Signed-off-by: jihandong <[email protected]>
1 parent 7d87768 commit 2593728

File tree

4 files changed

+644
-553
lines changed

4 files changed

+644
-553
lines changed

mlearning/tflite-micro/operators/neon/arm_convolve_s8.c

Lines changed: 176 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
/*
2-
* SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <[email protected]>
1+
/****************************************************************************
2+
* apps/mlearning/tflite-micro/operators/neon/arm_convolve_s8.c
3+
*
4+
* SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or
5+
* its affiliates <[email protected]>
36
*
47
* SPDX-License-Identifier: Apache-2.0
58
*
@@ -14,191 +17,210 @@
1417
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1518
* See the License for the specific language governing permissions and
1619
* limitations under the License.
17-
*/
20+
****************************************************************************/
21+
22+
/****************************************************************************
23+
* Included Files
24+
****************************************************************************/
1825

1926
#include <arm_neon.h>
2027
#include "arm_nnfunctions.h"
2128
#include "arm_nnsupportfunctions.h"
2229

23-
/**
24-
* @ingroup Public
25-
*/
26-
27-
/**
28-
* @addtogroup NNConv
29-
* @{
30-
*/
30+
/****************************************************************************
31+
* Public Functions
32+
****************************************************************************/
3133

32-
/*
33-
* Basic s8 convolution function.
34-
*
35-
* Refer header file for details. Optimal use case for the DSP/MVE implementation is when input and output channels
36-
* are multiples of 4 or atleast greater than 4.
34+
/* Basic s8 convolution function.
3735
*
36+
* Refer header file for details. Optimal use case for the DSP/MVE
37+
* implementation is when input and output channels are multiples of 4 or
38+
* atleast greater than 4.
3839
*/
39-
arm_cmsis_nn_status arm_convolve_s8(const cmsis_nn_context *ctx,
40-
const cmsis_nn_conv_params *conv_params,
41-
const cmsis_nn_per_channel_quant_params *quant_params,
42-
const cmsis_nn_dims *input_dims,
43-
const int8_t *input_data,
44-
const cmsis_nn_dims *filter_dims,
45-
const int8_t *filter_data,
46-
const cmsis_nn_dims *bias_dims,
47-
const int32_t *bias_data,
48-
const cmsis_nn_dims *output_dims,
49-
int8_t *output_data)
40+
41+
arm_cmsis_nn_status
42+
arm_convolve_s8(const cmsis_nn_context *ctx,
43+
const cmsis_nn_conv_params *conv_params,
44+
const cmsis_nn_per_channel_quant_params *quant_params,
45+
const cmsis_nn_dims *input_dims,
46+
const int8_t *input_data,
47+
const cmsis_nn_dims *filter_dims,
48+
const int8_t *filter_data,
49+
const cmsis_nn_dims *bias_dims,
50+
const int32_t *bias_data,
51+
const cmsis_nn_dims *output_dims,
52+
int8_t *output_data)
5053
{
51-
(void)bias_dims;
54+
(void)bias_dims;
5255

53-
if (ctx->buf == NULL)
56+
if (ctx->buf == NULL)
5457
{
55-
return ARM_CMSIS_NN_ARG_ERROR;
58+
return ARM_CMSIS_NN_ARG_ERROR;
5659
}
57-
int16_t *buffer_a = (int16_t *)ctx->buf;
58-
59-
const int32_t input_batches = input_dims->n;
60-
const uint16_t input_x = input_dims->w;
61-
const uint16_t input_y = input_dims->h;
62-
const uint16_t input_ch = input_dims->c;
63-
const uint16_t kernel_x = filter_dims->w;
64-
const uint16_t kernel_y = filter_dims->h;
65-
const uint16_t output_x = output_dims->w;
66-
const uint16_t output_y = output_dims->h;
67-
const uint16_t output_ch = output_dims->c;
68-
69-
const uint16_t pad_x = conv_params->padding.w;
70-
const uint16_t pad_y = conv_params->padding.h;
71-
const uint16_t stride_x = conv_params->stride.w;
72-
const uint16_t stride_y = conv_params->stride.h;
73-
const int32_t dilation_x = conv_params->dilation.w;
74-
const int32_t dilation_y = conv_params->dilation.h;
75-
const int32_t out_offset = conv_params->output_offset;
76-
const int32_t out_activation_min = conv_params->activation.min;
77-
const int32_t out_activation_max = conv_params->activation.max;
78-
const int32_t rhs_cols = kernel_x * kernel_y * input_ch;
79-
const int32_t input_offset = conv_params->input_offset;
80-
81-
int32_t *output_mult = quant_params->multiplier;
82-
int32_t *output_shift = quant_params->shift;
83-
84-
int i_batch;
85-
for (i_batch = 0; i_batch < input_batches; i_batch++)
60+
61+
int16_t *buffer_a = (int16_t *)ctx->buf;
62+
63+
const int32_t input_batches = input_dims->n;
64+
const uint16_t input_x = input_dims->w;
65+
const uint16_t input_y = input_dims->h;
66+
const uint16_t input_ch = input_dims->c;
67+
const uint16_t kernel_x = filter_dims->w;
68+
const uint16_t kernel_y = filter_dims->h;
69+
const uint16_t output_x = output_dims->w;
70+
const uint16_t output_y = output_dims->h;
71+
const uint16_t output_ch = output_dims->c;
72+
73+
const uint16_t pad_x = conv_params->padding.w;
74+
const uint16_t pad_y = conv_params->padding.h;
75+
const uint16_t stride_x = conv_params->stride.w;
76+
const uint16_t stride_y = conv_params->stride.h;
77+
const int32_t dilation_x = conv_params->dilation.w;
78+
const int32_t dilation_y = conv_params->dilation.h;
79+
const int32_t out_offset = conv_params->output_offset;
80+
const int32_t out_activation_min = conv_params->activation.min;
81+
const int32_t out_activation_max = conv_params->activation.max;
82+
const int32_t rhs_cols = kernel_x * kernel_y * input_ch;
83+
const int32_t input_offset = conv_params->input_offset;
84+
85+
int32_t *output_mult = quant_params->multiplier;
86+
int32_t *output_shift = quant_params->shift;
87+
88+
int i_batch;
89+
for (i_batch = 0; i_batch < input_batches; i_batch++)
8690
{
87-
const int32_t remainder = rhs_cols % 4;
88-
const int32_t aligned_rhs_cols = remainder != 0 ? rhs_cols + 4 - remainder : rhs_cols;
89-
/**
90-
* Use Im2col to speed up conv2d calculations.
91-
* Use as a ping-pong buffer for unordered elements.
92-
*/
93-
int8_t *im2col_buf = (int8_t *)buffer_a + aligned_rhs_cols * 2;
94-
int16_t *im2col_buf_start_s16 = buffer_a;
95-
int8_t *out = output_data;
96-
int32_t lhs_rows = 0;
97-
/* This part implements the im2col function */
98-
for (int i_out_x = 0; i_out_x < output_x; i_out_x++)
91+
const int32_t remainder = rhs_cols % 4;
92+
const int32_t aligned_rhs_cols = remainder != 0 ?
93+
rhs_cols + 4 - remainder : rhs_cols;
94+
95+
/**
96+
* Use Im2col to speed up conv2d calculations.
97+
* Use as a ping-pong buffer for unordered elements.
98+
*/
99+
100+
int8_t *im2col_buf = (int8_t *)buffer_a + aligned_rhs_cols * 2;
101+
int16_t *im2col_buf_start_s16 = buffer_a;
102+
int8_t *out = output_data;
103+
int32_t lhs_rows = 0;
104+
105+
/* This part implements the im2col function */
106+
107+
for (int i_out_x = 0; i_out_x < output_x; i_out_x++)
99108
{
100-
const int32_t base_idx_x = stride_x * i_out_x - pad_x;
101-
for (int i_out_y = 0; i_out_y < output_y; i_out_y++)
109+
const int32_t base_idx_x = stride_x * i_out_x - pad_x;
110+
for (int i_out_y = 0; i_out_y < output_y; i_out_y++)
102111
{
103-
const int32_t base_idx_y = stride_y * i_out_y - pad_y;
104-
for (int32_t i_ker_x = 0; i_ker_x < kernel_x; i_ker_x++)
112+
const int32_t base_idx_y = stride_y * i_out_y - pad_y;
113+
for (int32_t i_ker_x = 0; i_ker_x < kernel_x; i_ker_x++)
105114
{
106-
int32_t k_x = base_idx_x + dilation_x * i_ker_x;
107-
int32_t k_y = base_idx_y - dilation_y;
108-
for (int32_t i_ker_y = 0; i_ker_y < kernel_y; i_ker_y++)
115+
int32_t k_x = base_idx_x + dilation_x * i_ker_x;
116+
int32_t k_y = base_idx_y - dilation_y;
117+
for (int32_t i_ker_y = 0; i_ker_y < kernel_y; i_ker_y++)
109118
{
110-
k_y += dilation_y;
111-
arm_memcpy_s8(im2col_buf,
112-
input_data + (k_y * input_x + k_x) * input_ch,
113-
input_ch);
114-
im2col_buf += input_ch;
119+
k_y += dilation_y;
120+
arm_memcpy_s8(im2col_buf,
121+
input_data + (k_y * input_x + k_x) * input_ch,
122+
input_ch);
123+
im2col_buf += input_ch;
115124
}
116125
}
117-
lhs_rows++;
118-
/* Extend the input data from int8 to int16, and add offset. */
119-
arm_q7_to_q15_with_offset(im2col_buf - rhs_cols,
120-
im2col_buf_start_s16,
121-
rhs_cols,
122-
(int16_t)input_offset);
123-
im2col_buf_start_s16 += aligned_rhs_cols;
124-
if (lhs_rows & 2)
126+
127+
lhs_rows++;
128+
129+
/* Extend the input data from int8 to int16, and add offset. */
130+
131+
arm_q7_to_q15_with_offset(im2col_buf - rhs_cols,
132+
im2col_buf_start_s16,
133+
rhs_cols,
134+
(int16_t)input_offset);
135+
im2col_buf_start_s16 += aligned_rhs_cols;
136+
if (lhs_rows & 2)
125137
{
126-
out = arm_nn_mat_mult_kernel_s8_s16(filter_data,
127-
buffer_a,
128-
output_ch,
129-
output_shift,
130-
output_mult,
131-
out_offset,
132-
out_activation_min,
133-
out_activation_max,
134-
rhs_cols,
135-
aligned_rhs_cols,
136-
bias_data,
137-
out);
138-
/* counter reset */
139-
im2col_buf_start_s16 = buffer_a;
140-
im2col_buf = (int8_t *)buffer_a + (aligned_rhs_cols << 1);
141-
lhs_rows = 0;
138+
out = arm_nn_mat_mult_kernel_s8_s16(filter_data,
139+
buffer_a,
140+
output_ch,
141+
output_shift,
142+
output_mult,
143+
out_offset,
144+
out_activation_min,
145+
out_activation_max,
146+
rhs_cols,
147+
aligned_rhs_cols,
148+
bias_data,
149+
out);
150+
151+
/* counter reset */
152+
153+
im2col_buf_start_s16 = buffer_a;
154+
im2col_buf = (int8_t *)buffer_a + (aligned_rhs_cols << 1);
155+
lhs_rows = 0;
142156
}
143157
}
144158
}
145-
if (lhs_rows != 0)
159+
160+
if (lhs_rows != 0)
146161
{
147-
const int8_t *ker_a = filter_data;
148-
int i;
149-
for (i = 0; i < output_ch; i++)
162+
const int8_t *ker_a = filter_data;
163+
int i;
164+
165+
for (i = 0; i < output_ch; i++)
150166
{
151-
/* Load the accumulator with bias first */
152-
uint16_t col_count = rhs_cols / 8;
153-
int32_t sum = 0;
154-
const int16_t *ip_as_col = buffer_a;
155-
int32x4_t res_s32 = vdupq_n_s32(0);
156-
if (bias_data)
167+
/* Load the accumulator with bias first */
168+
169+
uint16_t col_count = rhs_cols / 8;
170+
int32_t sum = 0;
171+
const int16_t *ip_as_col = buffer_a;
172+
int32x4_t res_s32 = vdupq_n_s32(0);
173+
if (bias_data)
157174
{
158-
sum = bias_data[i];
175+
sum = bias_data[i];
159176
}
160-
while (col_count)
177+
178+
while (col_count)
161179
{
162-
int8x8_t filter_s8 = vld1_s8(ker_a);
163-
int16x8_t input_s16 = vld1q_s16(ip_as_col);
164-
int16x8_t filter_s16 = vmovl_s8(filter_s8);
165-
ker_a += 8;
166-
ip_as_col += 8;
167-
res_s32 = vmlal_s16(res_s32,
168-
vget_low_s16(input_s16),
169-
vget_low_s16(filter_s16));
170-
res_s32 = vmlal_s16(res_s32,
171-
vget_high_s16(input_s16),
172-
vget_high_s16(filter_s16));
173-
col_count --;
180+
int8x8_t filter_s8 = vld1_s8(ker_a);
181+
int16x8_t input_s16 = vld1q_s16(ip_as_col);
182+
int16x8_t filter_s16 = vmovl_s8(filter_s8);
183+
ker_a += 8;
184+
ip_as_col += 8;
185+
res_s32 = vmlal_s16(res_s32,
186+
vget_low_s16(input_s16),
187+
vget_low_s16(filter_s16));
188+
res_s32 = vmlal_s16(res_s32,
189+
vget_high_s16(input_s16),
190+
vget_high_s16(filter_s16));
191+
col_count--;
174192
}
175-
sum += vgetq_lane_s32(res_s32, 0);
176-
sum += vgetq_lane_s32(res_s32, 1);
177-
sum += vgetq_lane_s32(res_s32, 2);
178-
sum += vgetq_lane_s32(res_s32, 3);
179-
col_count = rhs_cols % 8;
180-
while (col_count)
193+
194+
sum += vgetq_lane_s32(res_s32, 0);
195+
sum += vgetq_lane_s32(res_s32, 1);
196+
sum += vgetq_lane_s32(res_s32, 2);
197+
sum += vgetq_lane_s32(res_s32, 3);
198+
col_count = rhs_cols % 8;
199+
while (col_count)
181200
{
182-
int8_t ker_a1 = *ker_a++;
183-
int16_t ip_b1 = *ip_as_col++;
184-
sum += ker_a1 * ip_b1;
185-
col_count--;
201+
int8_t ker_a1 = *ker_a++;
202+
int16_t ip_b1 = *ip_as_col++;
203+
sum += ker_a1 * ip_b1;
204+
col_count--;
186205
}
187-
sum = arm_nn_requantize(sum, output_mult[i], output_shift[i]);
188-
sum += out_offset;
189-
sum = MAX(sum, out_activation_min);
190-
sum = MIN(sum, out_activation_max);
191-
*out++ = (int8_t)sum;
206+
207+
sum = arm_nn_requantize(sum,
208+
output_mult[i], output_shift[i]);
209+
sum += out_offset;
210+
sum = MAX(sum, out_activation_min);
211+
sum = MIN(sum, out_activation_max);
212+
*out++ = (int8_t)sum;
192213
}
193214
}
194-
/* Advance to the next batch */
195-
input_data += (input_x * input_y * input_ch);
196-
output_data += (output_x * output_y * output_ch);
215+
216+
/* Advance to the next batch */
217+
218+
input_data += (input_x * input_y * input_ch);
219+
output_data += (output_x * output_y * output_ch);
197220
}
198-
/* Return to application */
199-
return ARM_CMSIS_NN_SUCCESS;
221+
222+
/* Return to application */
223+
224+
return ARM_CMSIS_NN_SUCCESS;
200225
}
201226

202-
/**
203-
* @} end of NNConv group
204-
*/

0 commit comments

Comments
 (0)