opt++
nihui committed Apr 29, 2024
1 parent 3b379d7 commit 2c1a394
Showing 2 changed files with 46 additions and 66 deletions.
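
The pattern of the change is the same throughout both files: on the __ARM_FEATURE_DOTPROD paths, the old code reinterpreted the loaded int8 vector as 32-bit lanes, broadcast each 4-byte group across a full 16-byte register with vdupq_lane_s32 / vdupq_laneq_s32, and then called the plain vdotq_s32. The new code loads the input once and passes it to vdotq_lane_s32 / vdotq_laneq_s32, which take the lane index directly, so the explicit vreinterpret/vdup shuffles disappear. A minimal standalone sketch of the two equivalent forms, assuming a compiler targeting armv8.2-a with the dotprod extension (dot_old, dot_new and the bare x/w pointers are illustrative names, not part of the commit):

#include <arm_neon.h>

#if __ARM_FEATURE_DOTPROD
// before: reinterpret, broadcast 4-byte lane 0 of x across the register, then plain vdotq_s32
static int32x4_t dot_old(int32x4_t acc, const int8_t* x, const int8_t* w)
{
    int32x2_t _x01 = vreinterpret_s32_s8(vld1_s8(x));
    int8x16_t _x0 = vreinterpretq_s8_s32(vdupq_lane_s32(_x01, 0));
    return vdotq_s32(acc, vld1q_s8(w), _x0);
}

// after: vdotq_lane_s32 selects the 4-byte lane itself, no explicit broadcast needed
static int32x4_t dot_new(int32x4_t acc, const int8_t* x, const int8_t* w)
{
    int8x8_t _x = vld1_s8(x);
    return vdotq_lane_s32(acc, vld1q_s8(w), _x, 0);
}
#endif

Both accumulate the same four dot products of the 16 weight bytes against x[0..3]; the second form drops the broadcast and the extra temporary registers it needed.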
60 changes: 26 additions & 34 deletions src/layer/arm/gru_int8.h
@@ -537,17 +537,15 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
int32x4_t _sum2 = vdupq_n_s32(0);
for (; i + 7 < size; i += 8)
{
int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i));
int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 0));
int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 1));
int8x8_t _xi = vld1_s8(x + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
int8x16_t _w2 = vld1q_s8(kptr + 32);
int8x16_t _w3 = vld1q_s8(kptr + 48);
_gru_Rx0 = vdotq_s32(_gru_Rx0, _w0, _xi0);
_gru_Ux0 = vdotq_s32(_gru_Ux0, _w1, _xi0);
_sum1 = vdotq_s32(_sum1, _w2, _xi1);
_sum2 = vdotq_s32(_sum2, _w3, _xi1);
_gru_Rx0 = vdotq_lane_s32(_gru_Rx0, _w0, _xi, 0);
_gru_Ux0 = vdotq_lane_s32(_gru_Ux0, _w1, _xi, 0);
_sum1 = vdotq_lane_s32(_sum1, _w2, _xi, 1);
_sum2 = vdotq_lane_s32(_sum2, _w3, _xi, 1);

kptr += 64;
}
@@ -557,11 +555,11 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
for (; i + 3 < size; i += 4)
{
#if __ARM_FEATURE_DOTPROD
int8x16_t _xi = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(x + i)), 0));
int8x8_t _xi = vld1_s8(x + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_gru_Rx0 = vdotq_s32(_gru_Rx0, _w0, _xi);
_gru_Ux0 = vdotq_s32(_gru_Ux0, _w1, _xi);
_gru_Rx0 = vdotq_lane_s32(_gru_Rx0, _w0, _xi, 0);
_gru_Ux0 = vdotq_lane_s32(_gru_Ux0, _w1, _xi, 0);
#else
int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i));
int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0));
@@ -613,17 +611,15 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
_sum2 = vdupq_n_s32(0);
for (; i + 7 < num_output; i += 8)
{
int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i));
int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 0));
int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 1));
int8x8_t _h_cont = vld1_s8(hs + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
int8x16_t _w2 = vld1q_s8(kptr + 32);
int8x16_t _w3 = vld1q_s8(kptr + 48);
_gru_Rh0 = vdotq_s32(_gru_Rh0, _w0, _h_cont0);
_gru_Uh0 = vdotq_s32(_gru_Uh0, _w1, _h_cont0);
_sum1 = vdotq_s32(_sum1, _w2, _h_cont1);
_sum2 = vdotq_s32(_sum2, _w3, _h_cont1);
_gru_Rh0 = vdotq_lane_s32(_gru_Rh0, _w0, _h_cont, 0);
_gru_Uh0 = vdotq_lane_s32(_gru_Uh0, _w1, _h_cont, 0);
_sum1 = vdotq_lane_s32(_sum1, _w2, _h_cont, 1);
_sum2 = vdotq_lane_s32(_sum2, _w3, _h_cont, 1);

kptr += 64;
}
@@ -633,11 +629,11 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
for (; i + 3 < num_output; i += 4)
{
#if __ARM_FEATURE_DOTPROD
int8x16_t _h_cont = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(hs + i)), 0));
int8x8_t _h_cont = vld1_s8(hs + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_gru_Rh0 = vdotq_s32(_gru_Rh0, _w0, _h_cont);
_gru_Uh0 = vdotq_s32(_gru_Uh0, _w1, _h_cont);
_gru_Rh0 = vdotq_lane_s32(_gru_Rh0, _w0, _h_cont, 0);
_gru_Uh0 = vdotq_lane_s32(_gru_Uh0, _w1, _h_cont, 0);
#else
int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i));
int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0));
@@ -712,13 +708,11 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
_sum1 = vdupq_n_s32(0);
for (; i + 7 < num_output; i += 8)
{
int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i));
int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 0));
int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 1));
int8x8_t _h_cont = vld1_s8(hs + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_gru_Nh0 = vdotq_s32(_gru_Nh0, _w0, _h_cont0);
_sum1 = vdotq_s32(_sum1, _w1, _h_cont1);
_gru_Nh0 = vdotq_lane_s32(_gru_Nh0, _w0, _h_cont, 0);
_sum1 = vdotq_lane_s32(_sum1, _w1, _h_cont, 1);

kptr += 32;
}
@@ -727,9 +721,9 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
for (; i + 3 < num_output; i += 4)
{
#if __ARM_FEATURE_DOTPROD
int8x16_t _h_cont = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(hs + i)), 0));
int8x8_t _h_cont = vld1_s8(hs + i);
int8x16_t _w = vld1q_s8(kptr);
_gru_Nh0 = vdotq_s32(_gru_Nh0, _w, _h_cont);
_gru_Nh0 = vdotq_lane_s32(_gru_Nh0, _w, _h_cont, 0);
#else
int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i));
int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0));
@@ -770,13 +764,11 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
_sum1 = vdupq_n_s32(0);
for (; i + 7 < size; i += 8)
{
int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i));
int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 0));
int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 1));
int8x8_t _xi = vld1_s8(x + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_gru_Nx0 = vdotq_s32(_gru_Nx0, _w0, _xi0);
_sum1 = vdotq_s32(_sum1, _w1, _xi1);
_gru_Nx0 = vdotq_lane_s32(_gru_Nx0, _w0, _xi, 0);
_sum1 = vdotq_lane_s32(_sum1, _w1, _xi, 1);

kptr += 32;
}
@@ -785,9 +777,9 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
for (; i + 3 < size; i += 4)
{
#if __ARM_FEATURE_DOTPROD
int8x16_t _xi = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(x + i)), 0));
int8x8_t _xi = vld1_s8(x + i);
int8x16_t _w = vld1q_s8(kptr);
_gru_Nx0 = vdotq_s32(_gru_Nx0, _w, _xi);
_gru_Nx0 = vdotq_lane_s32(_gru_Nx0, _w, _xi, 0);
#else
int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i));
int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0));
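In gru_int8.h the widest dotprod loop consumes 8 input bytes per iteration: a single int8x8_t load supplies two 4-byte lanes, with lane 0 feeding the R/U gate accumulators and lane 1 feeding the two additional running sums. A hypothetical sketch of that loop body under the same armv8.2-a+dotprod assumption (gru_dot8 and its pointer arguments are illustrative names, not from the source):

#include <arm_neon.h>

#if __ARM_FEATURE_DOTPROD
// one iteration of the 8-wide GRU input loop: 8 int8 inputs against 64 int8 weights
static void gru_dot8(int32x4_t* Rx, int32x4_t* Ux, int32x4_t* sum1, int32x4_t* sum2,
                     const int8_t* x, const int8_t* kptr)
{
    int8x8_t _xi = vld1_s8(x);                            // bytes 0..7 = lanes 0 and 1
    *Rx   = vdotq_lane_s32(*Rx,   vld1q_s8(kptr),      _xi, 0);
    *Ux   = vdotq_lane_s32(*Ux,   vld1q_s8(kptr + 16), _xi, 0);
    *sum1 = vdotq_lane_s32(*sum1, vld1q_s8(kptr + 32), _xi, 1);
    *sum2 = vdotq_lane_s32(*sum2, vld1q_s8(kptr + 48), _xi, 1);
}
#endif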
52 changes: 20 additions & 32 deletions src/layer/arm/lstm_int8.h
@@ -268,31 +268,25 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
int32x4_t _sum3 = vdupq_n_s32(0);
for (; i + 15 < size; i += 16)
{
int32x4_t _xi01 = vreinterpretq_s32_s8(vld1q_s8(x + i));
int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 0));
int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 1));
int8x16_t _xi2 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 2));
int8x16_t _xi3 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 3));
int8x16_t _xi = vld1q_s8(x + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
int8x16_t _w2 = vld1q_s8(kptr + 32);
int8x16_t _w3 = vld1q_s8(kptr + 48);
_lstm_IFOGx0 = vdotq_s32(_lstm_IFOGx0, _w0, _xi0);
_sum1 = vdotq_s32(_sum1, _w1, _xi1);
_sum2 = vdotq_s32(_sum2, _w2, _xi2);
_sum3 = vdotq_s32(_sum3, _w3, _xi3);
_lstm_IFOGx0 = vdotq_laneq_s32(_lstm_IFOGx0, _w0, _xi, 0);
_sum1 = vdotq_laneq_s32(_sum1, _w1, _xi, 1);
_sum2 = vdotq_laneq_s32(_sum2, _w2, _xi, 2);
_sum3 = vdotq_laneq_s32(_sum3, _w3, _xi, 3);

kptr += 64;
}
for (; i + 7 < size; i += 8)
{
int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i));
int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 0));
int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 1));
int8x8_t _xi = vld1_s8(x + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_lstm_IFOGx0 = vdotq_s32(_lstm_IFOGx0, _w0, _xi0);
_sum1 = vdotq_s32(_sum1, _w1, _xi1);
_lstm_IFOGx0 = vdotq_lane_s32(_lstm_IFOGx0, _w0, _xi, 0);
_sum1 = vdotq_lane_s32(_sum1, _w1, _xi, 1);

kptr += 32;
}
@@ -303,9 +297,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
for (; i + 3 < size; i += 4)
{
#if __ARM_FEATURE_DOTPROD
int8x16_t _xi = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(x + i)), 0));
int8x8_t _xi = vld1_s8(x + i);
int8x16_t _w = vld1q_s8(kptr);
_lstm_IFOGx0 = vdotq_s32(_lstm_IFOGx0, _w, _xi);
_lstm_IFOGx0 = vdotq_lane_s32(_lstm_IFOGx0, _w, _xi, 0);
#else
int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i));
int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0));
@@ -348,31 +342,25 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
_sum3 = vdupq_n_s32(0);
for (; i + 15 < num_output; i += 16)
{
int32x4_t _h_cont01 = vreinterpretq_s32_s8(vld1q_s8(hs + i));
int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 0));
int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 1));
int8x16_t _h_cont2 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 2));
int8x16_t _h_cont3 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 3));
int8x16_t _h_cont = vld1q_s8(hs + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
int8x16_t _w2 = vld1q_s8(kptr + 32);
int8x16_t _w3 = vld1q_s8(kptr + 48);
_lstm_IFOGh0 = vdotq_s32(_lstm_IFOGh0, _w0, _h_cont0);
_sum1 = vdotq_s32(_sum1, _w1, _h_cont1);
_sum2 = vdotq_s32(_sum2, _w2, _h_cont2);
_sum3 = vdotq_s32(_sum3, _w3, _h_cont3);
_lstm_IFOGh0 = vdotq_laneq_s32(_lstm_IFOGh0, _w0, _h_cont, 0);
_sum1 = vdotq_laneq_s32(_sum1, _w1, _h_cont, 1);
_sum2 = vdotq_laneq_s32(_sum2, _w2, _h_cont, 2);
_sum3 = vdotq_laneq_s32(_sum3, _w3, _h_cont, 3);

kptr += 64;
}
for (; i + 7 < num_output; i += 8)
{
int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i));
int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 0));
int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 1));
int8x8_t _h_cont = vld1_s8(hs + i);
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_lstm_IFOGh0 = vdotq_s32(_lstm_IFOGh0, _w0, _h_cont0);
_sum1 = vdotq_s32(_sum1, _w1, _h_cont1);
_lstm_IFOGh0 = vdotq_lane_s32(_lstm_IFOGh0, _w0, _h_cont, 0);
_sum1 = vdotq_lane_s32(_sum1, _w1, _h_cont, 1);

kptr += 32;
}
@@ -383,9 +371,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
for (; i + 3 < num_output; i += 4)
{
#if __ARM_FEATURE_DOTPROD
int8x16_t _h_cont = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(hs + i)), 0));
int8x8_t _h_cont = vld1_s8(hs + i);
int8x16_t _w = vld1q_s8(kptr);
_lstm_IFOGh0 = vdotq_s32(_lstm_IFOGh0, _w, _h_cont);
_lstm_IFOGh0 = vdotq_lane_s32(_lstm_IFOGh0, _w, _h_cont, 0);
#else
int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i));
int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0));
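In lstm_int8.h the widest loop consumes 16 input bytes at a time, so the full int8x16_t load provides four 4-byte lanes and vdotq_laneq_s32 (the q-register lane variant, AArch64 only) indexes lanes 0..3 against four 16-byte weight blocks. A hypothetical sketch of that loop body, under the same assumptions as above (lstm_dot16 is an illustrative name, not from the source):

#include <arm_neon.h>

#if __ARM_FEATURE_DOTPROD
// one iteration of the 16-wide LSTM input loop: 16 int8 inputs against 64 int8 weights
static void lstm_dot16(int32x4_t acc[4], const int8_t* x, const int8_t* kptr)
{
    int8x16_t _xi = vld1q_s8(x);                          // bytes 0..15 = lanes 0..3
    acc[0] = vdotq_laneq_s32(acc[0], vld1q_s8(kptr),      _xi, 0);
    acc[1] = vdotq_laneq_s32(acc[1], vld1q_s8(kptr + 16), _xi, 1);
    acc[2] = vdotq_laneq_s32(acc[2], vld1q_s8(kptr + 32), _xi, 2);
    acc[3] = vdotq_laneq_s32(acc[3], vld1q_s8(kptr + 48), _xi, 3);
}
#endif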
