From c4af2cd30c8416e657b0dc2e7312d72baeced0c9 Mon Sep 17 00:00:00 2001 From: Adrien Cassagne Date: Mon, 2 Jul 2018 18:02:44 +0200 Subject: [PATCH] Add 16-bit and 8-bit mullo. --- src/mipp_impl_AVX.hxx | 5 +++ src/mipp_impl_AVX512.hxx | 7 ++++ src/mipp_impl_NEON.hxx | 10 +++++ src/mipp_impl_SSE.hxx | 7 ++++ tests/src/arithmetic_operations/mul.cpp | 54 +++++++++++++++++++++++++ 5 files changed, 83 insertions(+) diff --git a/src/mipp_impl_AVX.hxx b/src/mipp_impl_AVX.hxx index bad83b4..05d656e 100644 --- a/src/mipp_impl_AVX.hxx +++ b/src/mipp_impl_AVX.hxx @@ -2085,6 +2085,11 @@ inline reg mul(const reg v1, const reg v2) { return _mm256_castsi256_ps(_mm256_mullo_epi32(_mm256_castps_si256(v1), _mm256_castps_si256(v2))); } + + template <> + inline reg mul(const reg v1, const reg v2) { + return _mm256_castsi256_ps(_mm256_mullo_epi16(_mm256_castps_si256(v1), _mm256_castps_si256(v2))); + } #endif // ------------------------------------------------------------------------------------------------------------ div diff --git a/src/mipp_impl_AVX512.hxx b/src/mipp_impl_AVX512.hxx index c5b032b..7167f7d 100644 --- a/src/mipp_impl_AVX512.hxx +++ b/src/mipp_impl_AVX512.hxx @@ -2563,6 +2563,13 @@ return _mm512_castsi512_ps(_mm512_mullo_epi32(_mm512_castps_si512(v1), _mm512_castps_si512(v2))); } +#if defined(__AVX512BW__) + template <> + inline reg mul(const reg v1, const reg v2) { + return _mm512_castsi512_ps(_mm512_mullo_epi16(_mm512_castps_si512(v1), _mm512_castps_si512(v2))); + } +#endif + // ------------------------------------------------------------------------------------------------------------ div #if defined(__AVX512F__) template <> diff --git a/src/mipp_impl_NEON.hxx b/src/mipp_impl_NEON.hxx index c781aa7..d40a6e8 100644 --- a/src/mipp_impl_NEON.hxx +++ b/src/mipp_impl_NEON.hxx @@ -1798,6 +1798,16 @@ return (reg) vmulq_s32((int32x4_t) v1, (int32x4_t) v2); } + template <> + inline reg mul(const reg v1, const reg v2) { + return (reg) vmulq_s16((int16x8_t) v1, (int16x8_t) v2); + } + + template <> + inline reg mul(const reg v1, const reg v2) { + return (reg) vmulq_s8((int8x16_t) v1, (int8x16_t) v2); + } + // ------------------------------------------------------------------------------------------------------------ div #ifdef __aarch64__ template <> diff --git a/src/mipp_impl_SSE.hxx b/src/mipp_impl_SSE.hxx index c6d2897..721aa07 100644 --- a/src/mipp_impl_SSE.hxx +++ b/src/mipp_impl_SSE.hxx @@ -1973,6 +1973,13 @@ } #endif +#ifdef __SSE2__ + template <> + inline reg mul(const reg v1, const reg v2) { + return _mm_castsi128_ps(_mm_mullo_epi16(_mm_castps_si128(v1), _mm_castps_si128(v2))); + } +#endif + // ------------------------------------------------------------------------------------------------------------ div template <> inline reg div(const reg v1, const reg v2) { diff --git a/tests/src/arithmetic_operations/mul.cpp b/tests/src/arithmetic_operations/mul.cpp index ced97b5..ca43b4d 100644 --- a/tests/src/arithmetic_operations/mul.cpp +++ b/tests/src/arithmetic_operations/mul.cpp @@ -30,6 +30,26 @@ void test_reg_mul() REQUIRE(*((T*)&r3 +i) == res); #endif } + + std::iota(inputs1, inputs1 + mipp::N(), std::numeric_limits::max() - mipp::N()); + std::iota(inputs2, inputs2 + mipp::N(), std::numeric_limits::max() - mipp::N()); + + std::shuffle(inputs1, inputs1 + mipp::N(), g); + std::shuffle(inputs2, inputs2 + mipp::N(), g); + + r1 = mipp::load(inputs1); + r2 = mipp::load(inputs2); + r3 = mipp::mul (r1, r2); + + for (auto i = 0; i < mipp::N(); i++) + { + T res = inputs1[i] * inputs2[i]; +#if defined(MIPP_NEON) && MIPP_INSTR_VERSION == 1 + REQUIRE(*((T*)&r3 +i) == Approx(res)); +#else + REQUIRE(*((T*)&r3 +i) == res); +#endif + } } #ifndef MIPP_NO @@ -44,6 +64,12 @@ TEST_CASE("Multiplication - mipp::reg", "[mipp::mul]") #if !defined(MIPP_SSE) || (defined(MIPP_SSE) && MIPP_INSTR_VERSION >= 41) SECTION("datatype = int32_t") { test_reg_mul(); } #endif +#if !defined(MIPP_SSE) || (defined(MIPP_SSE) && MIPP_INSTR_VERSION >= 2) + SECTION("datatype = int16_t") { test_reg_mul(); } +#endif +#endif +#if defined(MIPP_NEON) + SECTION("datatype = int8_t") { test_reg_mul(); } #endif } #endif @@ -72,8 +98,29 @@ void test_Reg_mul() REQUIRE(r3[i] == res); #endif } + + std::iota(inputs1, inputs1 + mipp::N(), std::numeric_limits::max() - mipp::N()); + std::iota(inputs2, inputs2 + mipp::N(), std::numeric_limits::max() - mipp::N()); + + std::shuffle(inputs1, inputs1 + mipp::N(), g); + std::shuffle(inputs2, inputs2 + mipp::N(), g); + + r1 = inputs1; + r2 = inputs2; + r3 = r1 * r2; + + for (auto i = 0; i < mipp::N(); i++) + { + T res = inputs1[i] * inputs2[i]; +#if defined(MIPP_NEON) && MIPP_INSTR_VERSION == 1 + REQUIRE(r3[i] == Approx(res)); +#else + REQUIRE(r3[i] == res); +#endif + } } +#ifndef MIPP_NO TEST_CASE("Multiplication - mipp::Reg", "[mipp::mul]") { #if defined(MIPP_64BIT) @@ -85,8 +132,15 @@ TEST_CASE("Multiplication - mipp::Reg", "[mipp::mul]") #if !defined(MIPP_SSE) || (defined(MIPP_SSE) && MIPP_INSTR_VERSION >= 41) SECTION("datatype = int32_t") { test_Reg_mul(); } #endif +#if !defined(MIPP_SSE) || (defined(MIPP_SSE) && MIPP_INSTR_VERSION >= 2) + SECTION("datatype = int16_t") { test_Reg_mul(); } +#endif +#endif +#if defined(MIPP_NEON) + SECTION("datatype = int8_t") { test_Reg_mul(); } #endif } +#endif template void test_reg_maskz_mul()