From 06e27206e5893a0632851a0cdc2e26b687455e5c Mon Sep 17 00:00:00 2001
From: Aelphy
Date: Mon, 19 Jun 2023 14:59:40 +0200
Subject: [PATCH] [xtensa] Widening ops, conversions, division, and
 gather_load improvements

---
 Makefile                                |   2 +
 src/CodeGen_Xtensa.cpp                  |   7 ++
 src/CodeGen_Xtensa_vectors.template.cpp | 141 ++++++++++++++++++++++--
 src/XtensaOptimize.cpp                  |  62 ++++++++++-
 4 files changed, 195 insertions(+), 17 deletions(-)

diff --git a/Makefile b/Makefile
index 957369a057f6..7d0b24339892 100644
--- a/Makefile
+++ b/Makefile
@@ -2484,6 +2484,8 @@ XTENSA_RUNTIME_SRC=$(ROOT_DIR)/src/runtime/alignment_128.cpp \
 	$(ROOT_DIR)/src/runtime/to_string.cpp \
 	$(ROOT_DIR)/src/runtime/posix_print.cpp \
 	$(ROOT_DIR)/src/runtime/posix_io.cpp \
+	$(ROOT_DIR)/src/runtime/posix_aligned_alloc.cpp \
+	$(ROOT_DIR)/src/runtime/posix_allocator.cpp \
 	$(ROOT_DIR)/src/runtime/xtensa_dma.cpp \
 
 XTENSA_RUNTIME_OBJS=$(patsubst $(ROOT_DIR)/src/runtime/%,$(BIN_DIR)/%,$(patsubst %.cpp,%.o,$(XTENSA_RUNTIME_SRC)))
diff --git a/src/CodeGen_Xtensa.cpp b/src/CodeGen_Xtensa.cpp
index 91df870bade0..b2ba731ef24d 100644
--- a/src/CodeGen_Xtensa.cpp
+++ b/src/CodeGen_Xtensa.cpp
@@ -411,6 +411,9 @@ string CodeGen_Xtensa::print_xtensa_call(const Call *op) {
             rhs << "IVP_ABSSUBUNX16U(" << args[0] + ", " + args[1] + ")";
         }
         return rhs.str();
+    } else if (op->name == "halide_xtensa_absd_u8") {
+        rhs << "IVP_ABSSUBU2NX8(" << args[0] + ", " + args[1] + ")";
+        return rhs.str();
     } else if (op->name == "halide_xtensa_narrow_i48_with_shift_u16") {
         rhs << "xb_vecNx16_rtor_xb_vecNx16U(IVP_PACKVRNRNX48(" << args[0] + ", " + args[1] + "))";
         return rhs.str();
@@ -465,6 +468,10 @@ void CodeGen_Xtensa::visit(const Div *op) {
         ostringstream rhs;
         rhs << "IVP_DIVN_2XF32(" << print_expr(op->a) << ", " << print_expr(op->b) << ")";
         print_assignment(op->type, rhs.str());
+    } else if (is_native_xtensa_vector<uint32_t>(op->type)) {
+        string sa = print_expr(op->a);
+        string sb = print_expr(op->b);
+        print_assignment(op->type, "halide_xtensa_div32(" + sa + ", " + sb + ")");
     } else {
         string sa = print_expr(op->a);
         string sb = print_expr(op->b);
diff --git a/src/CodeGen_Xtensa_vectors.template.cpp b/src/CodeGen_Xtensa_vectors.template.cpp
index d62e20a0ca26..ce6706012ff4 100644
--- a/src/CodeGen_Xtensa_vectors.template.cpp
+++ b/src/CodeGen_Xtensa_vectors.template.cpp
@@ -2212,8 +2212,9 @@ convert(const native_vector_u16_x2 &src)
 template<>
 HALIDE_ALWAYS_INLINE native_vector_u8
 convert<native_vector_u8, native_vector_i16_x2>(const native_vector_i16_x2 &src) {
-    xb_vec2Nx24 wide = IVP_CVT24S2NX16(src.native_vector[1], src.native_vector[0]);
-    return xb_vec2Nx8_rtor_xb_vec2Nx8U(IVP_PACKL2NX24(wide));
+    return IVP_SEL2NX8UI(IVP_MOV2NX8U_FROMNX16(src.native_vector[1]),
+                         IVP_MOV2NX8U_FROMNX16(src.native_vector[0]),
+                         IVP_SELI_8B_EXTRACT_1_OF_2_OFF_0);
 }
@@ -2367,12 +2368,12 @@ HALIDE_ALWAYS_INLINE native_vector_i32_x4 convert
 template<>
 HALIDE_ALWAYS_INLINE native_vector_i32_x2
 convert<native_vector_i32_x2, native_vector_i16>(const native_vector_i16 &src) {
-    const native_vector_i32 m = native_vector_i32(1U << (16 - 1));
-    native_vector_i32 x1 = IVP_MOVN_2X32_FROMNX16(
-        IVP_SELNX16I(native_vector_i16(0), src, IVP_SELI_16B_INTERLEAVE_1_LO));
-    native_vector_i32 x2 = IVP_MOVN_2X32_FROMNX16(
-        IVP_SELNX16I(native_vector_i16(0), src, IVP_SELI_16B_INTERLEAVE_1_HI));
-    return native_vector_i32_x2(native_vector_i32_x2::from_native_vector, (x1 ^ m) - m, (x2 ^ m) - m);
+    native_vector_i16 sign_val = src >> 15;
+    return native_vector_i32_x2(native_vector_i32_x2::from_native_vector,
+                                IVP_MOVN_2X32_FROMNX16(
+                                    IVP_SELNX16UI(sign_val, src, IVP_SELI_16B_INTERLEAVE_1_LO)),
+                                IVP_MOVN_2X32_FROMNX16(
+                                    IVP_SELNX16UI(sign_val, src, IVP_SELI_16B_INTERLEAVE_1_HI)));
 }
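
For reference, the rewritten i16-to-i32 conversion above sign-extends by interleaving each 16-bit lane with its sign mask rather than going through the xor/subtract trick. A scalar sketch of the same identity (names are illustrative, not part of the patch):

    #include <cstdint>

    // Model of the SELNX16UI interleave: the high half of each widened lane
    // is src >> 15 (all ones for negative values), the low half is the value.
    int32_t sign_extend_via_interleave(int16_t v) {
        uint16_t sign_mask = static_cast<uint16_t>(v >> 15);
        return static_cast<int32_t>((static_cast<uint32_t>(sign_mask) << 16) |
                                    static_cast<uint16_t>(v));
    }
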
@@ -2717,13 +2718,11 @@ HALIDE_ALWAYS_INLINE native_vector_u8 halide_xtensa_convert_concat_i16_to_u8(con
 }
 
 HALIDE_ALWAYS_INLINE native_vector_i8 halide_xtensa_convert_concat_u16_to_i8(const native_vector_u16 &a, const native_vector_u16 &b) {
-    xb_vec2Nx24 wide = IVP_CVT24U2NX16(xb_vecNx16U_rtor_xb_vecNx16(b), xb_vecNx16U_rtor_xb_vecNx16(a));
-    return IVP_PACKL2NX24(wide);
+    return IVP_SEL2NX8I(IVP_MOV2NX8_FROMNX16(b), IVP_MOV2NX8_FROMNX16(a), IVP_SELI_8B_EXTRACT_1_OF_2_OFF_0);
 }
 
 HALIDE_ALWAYS_INLINE native_vector_u8 halide_xtensa_convert_concat_u16_to_u8(const native_vector_u16 &a, const native_vector_u16 &b) {
-    xb_vec2Nx24 wide = IVP_CVT24U2NX16(xb_vecNx16U_rtor_xb_vecNx16(b), xb_vecNx16U_rtor_xb_vecNx16(a));
-    return xb_vec2Nx8_rtor_xb_vec2Nx8U(IVP_PACKL2NX24(wide));
+    return IVP_SEL2NX8UI(IVP_MOV2NX8_FROMNX16(b), IVP_MOV2NX8_FROMNX16(a), IVP_SELI_8B_EXTRACT_1_OF_2_OFF_0);
 }
 
 HALIDE_ALWAYS_INLINE native_vector_i16 halide_xtensa_convert_i8_low_i16(const native_vector_i8 &src, int native_lanes, int total_lines) {
@@ -2919,3 +2918,121 @@ HALIDE_ALWAYS_INLINE HALIDE_MAYBE_UNUSED native_vector_f32_x2 gather_load
+template<>
+HALIDE_ALWAYS_INLINE native_vector_u8
+convert<native_vector_u8, native_vector_u32_x4>(const native_vector_u32_x4 &src) {
+    xb_vec2Nx24 wide = IVP_CVT24UNX32L(src.native_vector[1], src.native_vector[0]);
+    IVP_CVT24UNX32H(wide, src.native_vector[3], src.native_vector[2]);
+    return IVP_PACKL2NX24(wide);
+}
+
+template<>
+HALIDE_ALWAYS_INLINE native_vector_u32_x4
+convert<native_vector_u32_x4, native_vector_i24>(const native_vector_i24 &src) {
+    return native_vector_u32_x4(native_vector_u32_x4::from_native_vector,
+                                IVP_CVT32S2NX24LL(src), IVP_CVT32S2NX24LH(src),
+                                IVP_CVT32S2NX24HL(src), IVP_CVT32S2NX24HH(src));
+}
+
+HALIDE_ALWAYS_INLINE native_vector_u32
+halide_xtensa_div_32_by_low16_of_32(native_vector_u32 &a, native_vector_u32 &b) {
+    native_vector_u32 quotient, remainder;
+    IVP_DIVN_2X32X16U(quotient, remainder, a, IVP_MOVNX16_FROMN_2X32(b), 0);
+    return quotient;
+}
+
+HALIDE_ALWAYS_INLINE native_vector_u32
+halide_xtensa_div32(native_vector_u32 dividend, native_vector_u32 divisor) {
+    xb_vecN_2x32Uv nsa;
+    xb_vecNx16U vec_divisor;
+    xb_vecN_2x32Uv quotient;
+    xb_vecN_2x32Uv remainder;
+    vboolN_2 predicate;
+
+    // Normalize the divisor to at most 16 significant bits, which is what
+    // the hardware divider accepts: shift right by max(0, 16 - NSAU).
+    nsa = IVP_NSAUN_2X32U(divisor);
+    predicate = IVP_LTUN_2X32U(16, nsa);
+    nsa = IVP_MOVN_2X32UT(0, (xb_vecN_2x32Uv)16 - nsa, predicate);
+    xb_vecN_2x32Uv divisor_nsa = IVP_SRLN_2X32U(divisor, nsa);
+
+    vec_divisor = IVP_MOVNX16_FROMN_2X32U(divisor_nsa);
+    IVP_DIVN_2X32X16U(quotient, remainder, dividend, vec_divisor, 0);
+    quotient = IVP_SRLN_2X32U(quotient, nsa);
+
+    // Truncating the divisor can make the estimate one too large; multiply
+    // back and correct where quotient * divisor exceeds the dividend.
+    xb_vecN_2x64w dividend_wide = IVP_MULUUN_2X16X32_0(IVP_MOVNX16_FROMN_2X32U(quotient), divisor);
+    xb_vecN_2x32Uv dividend_tmp = IVP_PACKLN_2X96(dividend_wide);
+    predicate = IVP_LTUN_2X32U(dividend, dividend_tmp);
+    IVP_SUBN_2X32UT(quotient, quotient, 1, predicate);
+    return quotient;
+}
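
A scalar model of the division strategy above, assuming only that the hardware divider (IVP_DIVN_2X32X16U) requires a divisor of at most 16 significant bits; the helper name is illustrative:

    #include <cstdint>

    // Normalize the divisor to <= 16 significant bits, divide, undo the
    // shift, then apply the same single correction step as the vector code.
    uint32_t div32_model(uint32_t dividend, uint32_t divisor) {
        int significant = 32 - __builtin_clz(divisor);  // divisor must be non-zero (GCC/Clang builtin)
        int shift = significant > 16 ? significant - 16 : 0;
        uint32_t q = (dividend / (divisor >> shift)) >> shift;
        if (static_cast<uint64_t>(q) * divisor > dividend) {
            q -= 1;  // the truncated divisor can overestimate by one
        }
        return q;
    }
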
+
+HALIDE_ALWAYS_INLINE native_vector_u16
+halide_xtensa_narrow_with_rounding_shift_u16(const native_vector_u32_x2 &a, uint32_t shift) {
+    xb_vecNx48 wide = convert<xb_vecNx48, native_vector_u32_x2>(a);
+    // Add rounding factor.
+    native_vector_u16 v1 = IVP_SLLNX16U(1, (shift - 1));
+    IVP_MULUUANX16(wide, v1, 1);
+    return xb_vecNx16_rtor_xb_vecNx16U(IVP_PACKVRNRNX48(wide, shift));
+}
+
+HALIDE_ALWAYS_INLINE native_vector_u16
+halide_xtensa_narrow_i48_with_rounding_shift_u16(const native_vector_i48 &a, uint32_t shift) {
+    if (15 == shift) {
+        return IVP_PACKQNX48(a);
+    }
+    xb_vecNx48 wide = a;
+    // Add rounding factor.
+    native_vector_u16 v1 = IVP_SLLNX16U(1, (shift - 1));
+    IVP_MULUUANX16(wide, v1, 1);
+    return xb_vecNx16_rtor_xb_vecNx16U(IVP_PACKVRNRNX48(wide, shift));
+}
+
+HALIDE_ALWAYS_INLINE native_vector_i48
+halide_xtensa_widen_mul_sub_i48(const native_vector_i48 &a, const native_vector_i16 &b, const native_vector_i16 &c) {
+    native_vector_i48 r = a;
+    IVP_MULSNX16(r, b, c);
+    return r;
+}
+
+template<>
+HALIDE_ALWAYS_INLINE HALIDE_MAYBE_UNUSED native_vector_u8
+gather_load(const void *base, const native_vector_i16_x2 &offset) {
+    auto addresses1 = xb_vecNx16_rtor_xb_vecNx16U(offset.native_vector[0]);
+    auto output1 = IVP_GATHERDNX8U(
+        IVP_GATHERANX8U(
+            (const uint8_t *)base,
+            addresses1));
+
+    auto addresses2 = xb_vecNx16_rtor_xb_vecNx16U(offset.native_vector[1]);
+    auto output2 = IVP_GATHERDNX8U(
+        IVP_GATHERANX8U(
+            (const uint8_t *)base,
+            addresses2));
+
+    // NOTE(aelphy): the intrinsic for gathering 8-bit elements extends them
+    // to 16 bits, so a conversion back to 8 bits is needed.
+    return convert<native_vector_u8, native_vector_u16_x2>(
+        native_vector_u16_x2(native_vector_u16_x2::from_native_vector, output1, output2));
+}
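
The gathered u8 load above has simple scalar semantics; a sketch (illustrative, not part of the patch), keeping in mind that the hardware gathers bytes into 16-bit lanes and the vector version narrows back afterwards:

    #include <cstddef>
    #include <cstdint>

    // Scalar model of the u8 gather_load: each 16-bit offset selects one byte.
    void gather_u8_model(const uint8_t *base, const uint16_t *offsets,
                         uint8_t *out, size_t lanes) {
        for (size_t i = 0; i < lanes; ++i) {
            out[i] = base[offsets[i]];
        }
    }

diff --git a/src/XtensaOptimize.cpp b/src/XtensaOptimize.cpp
index 4f9f212f7bb3..76357ff5fcbc 100644
--- a/src/XtensaOptimize.cpp
+++ b/src/XtensaOptimize.cpp
@@ -590,6 +590,33 @@ class MatchXtensaPatterns : public IRGraphMutator {
         return call;
     }
 
+    static Expr halide_xtensa_widen_add_u24(Expr v0, Expr v1) {
+        Expr call = Call::make(wild_i24x.type(), "halide_xtensa_widen_add_u24", {std::move(v0), std::move(v1)}, Call::PureExtern);
+        return call;
+    }
+
+    static Expr halide_xtensa_widen_accum_u24(Expr v0, Expr v1) {
+        Expr call = Call::make(wild_i24x.type(), "halide_xtensa_widen_accum_u24", {std::move(v0), std::move(v1)}, Call::PureExtern);
+        return call;
+    }
+
+    static Expr halide_xtensa_widen_mul_add_u24(Expr v0, Expr v1, Expr v2) {
+        Expr call = Call::make(wild_i24x.type(), "halide_xtensa_widen_mul_add_u24", {std::move(v0), std::move(v1), std::move(v2)}, Call::PureExtern);
+        return call;
+    }
+
+    static Expr halide_xtensa_widen_pair_mul_add_u24(Expr w, Expr v0, Expr v1, Expr v2, Expr v3) {
+        Expr call = Call::make(wild_i24x.type(), "halide_xtensa_widen_pair_mul_add_u24",
+                               {std::move(w), std::move(v0), std::move(v1), std::move(v2), std::move(v3)},
+                               Call::PureExtern);
+        return call;
+    }
+
+    static Expr halide_xtensa_widen_mul_sub_i48(Expr v0, Expr v1, Expr v2) {
+        Expr call = Call::make(wild_i48x.type(), "halide_xtensa_widen_mul_sub_i48", {std::move(v0), std::move(v1), std::move(v2)}, Call::PureExtern);
+        return call;
+    }
+
     Expr visit(const Add *op) override {
         if (op->type.is_vector()) {
             static const std::vector<Pattern> adds = {
@@ -631,7 +658,7 @@ class MatchXtensaPatterns : public IRGraphMutator {
                  wild_i24x + call("halide_xtensa_widen_mul_i24", wild_i24x, {wild_i8x, wild_i8x})},
 
                 {"halide_xtensa_widen_quad_mul_add_i24",
-                 wild_i24x + call("halide_xtensa_widen_quad_mul_i24", wild_i24x, {wild_i8x, wild_i8x, wild_i8x, wild_i8x, wild_i8x})},
+                 wild_i24x + call("halide_xtensa_widen_quad_mul_i24", wild_i24x, {wild_i8x, wild_i8x, wild_i8x, wild_i8x, wild_i8})},
 
                 // Add to accumulator type.
                 // Paired add.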
@@ -651,6 +678,14 @@ class MatchXtensaPatterns : public IRGraphMutator {
                 {"halide_xtensa_widen_mul_add_i64", widening_mul(wild_i32x, wild_i32x) + bc(wild_i64), Pattern::NarrowOp2 | Pattern::AccumulatorOutput64},
                 {"halide_xtensa_widen_mul_add_i64", widening_mul(wild_i32x, wild_i32x) + wild_i64x, Pattern::NarrowOp2 | Pattern::AccumulatorOutput64},
                 {"halide_xtensa_widen_mul_add_i64", i32(wild_i64x) + i32(call("halide_xtensa_mul_i32", wild_i64x, {wild_i32x, wild_i32x})), Pattern::AccumulatorOutput64},
+
+                {"halide_xtensa_widen_pair_mul_add_u24", i16(halide_xtensa_widen_mul_add_u24(wild_i24x, wild_u8x, wild_u8x)) + i16(halide_xtensa_widen_mul_u24(wild_u8x, wild_u8x)), Pattern::AccumulatorOutput24},
+                {"halide_xtensa_widen_pair_mul_add_u24", halide_xtensa_widen_mul_add_u24(wild_i24x, wild_u8x, wild_u8x) + halide_xtensa_widen_mul_u24(wild_u8x, wild_u8x)},
+
+                {"halide_xtensa_mul_add_u16", wild_u16x + wild_u16x * wild_u16x},
+
+                {"halide_xtensa_widen_add_u24", i24(wild_u8x) + i24(wild_u8x), Pattern::AccumulatorOutput24},
+                {"halide_xtensa_widen_accum_u24", wild_i24x + i24(wild_u8x), Pattern::AccumulatorOutput24},
             };
 
             Expr new_expr = apply_commutative_patterns(op, adds, this);
@@ -673,6 +708,8 @@ class MatchXtensaPatterns : public IRGraphMutator {
                 // {"halide_xtensa_pred_sub_i16", wild_i16x - select(wild_u1x, wild_i16x, wild_i16x)},
                 // {"halide_xtensa_pred_sub_i32", wild_i32x - select(wild_u1x, wild_i32x, wild_i32x)},
                 {"halide_xtensa_widen_mul_sub_u24", wild_i24x - halide_xtensa_widen_mul_u24(wild_u8x, wild_u8x)},
+                {"halide_xtensa_widen_mul_sub_i48", i32(wild_i48x) - i32(halide_xtensa_widen_mul_i48(wild_i16x, wild_i16x)), Pattern::AccumulatorOutput48},
+                {"halide_xtensa_widen_mul_sub_i48", wild_i48x - halide_xtensa_widen_mul_i48(wild_i16x, wild_i16x)},
             };
 
             Expr new_expr = apply_patterns(op, subs, this);
@@ -868,6 +905,7 @@ class MatchXtensaPatterns : public IRGraphMutator {
             {"halide_xtensa_convert_concat_i32_to_u16", u16(halide_xtensa_concat_from_native_i32(wild_i32x, wild_i32x))},
             {"halide_xtensa_convert_concat_u32_to_i16", i16(halide_xtensa_concat_from_native_u32(wild_u32x, wild_u32x))},
             {"halide_xtensa_convert_concat_u32_to_u16", u16(halide_xtensa_concat_from_native_u32(wild_u32x, wild_u32x))},
+            {"halide_xtensa_narrow_with_rounding_shift_u16", u16(rounding_shift_right(wild_u32x, bc(wild_u32)))},
         };
         if (op->type.is_vector()) {
             Expr cast = op;
@@ -952,11 +990,18 @@ class MatchXtensaPatterns : public IRGraphMutator {
             // that they generate.
             internal_assert(op->args.size() == 3);
             return mutate(lower_lerp(op->type, op->args[0], op->args[1], op->args[2], target));
-        } else if (op->is_intrinsic(Call::absd) && op->type.is_vector() && op->type.is_uint() && (op->type.bits() == 16)) {
+        } else if (op->is_intrinsic(Call::absd) && op->type.is_vector() && op->type.is_uint()) {
             internal_assert(op->args.size() == 2);
-            return Call::make(op->type, "halide_xtensa_absd_i16",
+
+            if (op->type.bits() == 16) {
+                return Call::make(op->type, "halide_xtensa_absd_i16",
+                                  {mutate(op->args[0]), mutate(op->args[1])},
+                                  Call::PureExtern);
+            } else if (op->type.bits() == 8) {
+                return Call::make(op->type, "halide_xtensa_absd_u8",
                               {mutate(op->args[0]), mutate(op->args[1])},
                               Call::PureExtern);
+            }
         } else if (op->is_intrinsic(Call::widening_shift_left)) {
             // Replace widening left shift with multiplication.
             const uint64_t *c = as_const_uint(op->args[1]);
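
For context, a sketch of front-end Halide code that should now be able to take the new paths; the pipeline itself is illustrative and not part of the patch:

    #include "Halide.h"
    using namespace Halide;

    void build_examples() {
        ImageParam in_a(UInt(8), 1), in_b(UInt(8), 1);
        Var x;

        // u8 absd is now matched directly (halide_xtensa_absd_u8,
        // lowered to IVP_ABSSUBU2NX8).
        Func diff;
        diff(x) = absd(in_a(x), in_b(x));

        // A widening u8 add: u16(u8) + u16(u8) becomes widening_add, which
        // the new pattern accumulates in the 24-bit type.
        Func wsum;
        wsum(x) = cast<uint16_t>(in_a(x)) + cast<uint16_t>(in_b(x));
    }
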
@@ -1069,8 +1114,7 @@ class MatchXtensaPatterns : public IRGraphMutator {
             {"halide_xtensa_widen_quad_mul_add_i24",
              call("halide_xtensa_widen_pair_mul_add_i24", wild_i24x, {call("halide_xtensa_widen_pair_mul_add_i24", wild_i24x, {wild_i24x, wild_i8x, wild_i8, wild_i8x, wild_i8}), wild_i8x, wild_i8, wild_i8x, wild_i8})},
 
             {"halide_xtensa_widen_pair_mul_add_i24",
-             call("halide_xtensa_widen_mul_add_i24", wild_i24x, {call("halide_xtensa_widen_mul_add_i24", wild_i24x, {wild_i24x, wild_i8x, wild_i8}), wild_i8x, wild_i8})},
-
+             call("halide_xtensa_widen_mul_add_i24", wild_i24x, {call("halide_xtensa_widen_mul_add_i24", wild_i24x, {wild_i24x, wild_i8x, wild_i8x}), wild_i8x, wild_i8x})},
             {"halide_xtensa_widen_pair_mul_add_i48",
              call("halide_xtensa_widen_mul_add_i48", wild_i48x, {call("halide_xtensa_widen_mul_add_i48", wild_i48x, {wild_i48x, wild_i16x, wild_i16x}), wild_i16x, wild_i16x})},
@@ -1115,6 +1159,14 @@ class MatchXtensaPatterns : public IRGraphMutator {
             {"halide_xtensa_narrow_i48_with_shift_i32", i32(wild_i48x) >> wild_i32},
             {"halide_xtensa_narrow_i48_with_shift_u32", u32(wild_i48x) >> wild_u32},
 
+            {"halide_xtensa_widen_add_u24", widening_add(wild_u8x, wild_u8x), Pattern::AccumulatorOutput24},
+            {"halide_xtensa_widen_accum_u24", widening_add(wild_i24x, wild_u8x), Pattern::AccumulatorOutput24},
+
+            {"halide_xtensa_widen_pair_mul_add_u24",
+             call("halide_xtensa_widen_mul_add_u24", wild_i24x,
+                  {call("halide_xtensa_widen_mul_add_u24", wild_i24x, {wild_i24x, wild_u8x, wild_u8x}), wild_u8x, wild_u8x})},
+
+            {"halide_xtensa_narrow_i48_with_rounding_shift_u16", call("halide_xtensa_narrow_with_rounding_shift_u16", wild_u16x, {u32(wild_i48x), wild_u32})},
+
             // Predicated saturated add/sub.
             // NOTE(vksnk): patterns below are for predicated instructions and look like they may
             // be more efficient, but they are not according to simulator. We will need to check with
             // Cadence about this.
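
The rounding-shift narrowing matched above corresponds to round-to-nearest narrowing; a scalar sketch assuming shift >= 1 and ignoring the saturation applied by the pack instruction (the name is illustrative):

    #include <cstdint>

    // Model of halide_xtensa_narrow_with_rounding_shift_u16: adding
    // 1 << (shift - 1) before the shift rounds to nearest instead of down.
    uint16_t narrow_round_shift_model(uint32_t v, uint32_t shift) {
        return static_cast<uint16_t>((v + (1u << (shift - 1))) >> shift);
    }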