From 2d3f3b0e1f9ed0d1db20b82ef0681ba83f2a9f74 Mon Sep 17 00:00:00 2001 From: Shawn Xu Date: Sun, 12 Jan 2025 19:38:02 -0800 Subject: [PATCH] cleanup no functional change --- src/nnue/nnue_feature_transformer.h | 91 +++++++++++++++++++++-------- 1 file changed, 66 insertions(+), 25 deletions(-) diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index 8649d9521bb..78cca7faff3 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -228,33 +228,74 @@ class FeatureTransformer { return FeatureSet::HashValue ^ (OutputDimensions * 2); } - static constexpr void order_packs([[maybe_unused]] uint64_t* v) { -#if defined(USE_AVX512) // _mm512_packs_epi16 ordering - uint64_t tmp0 = v[2], tmp1 = v[3]; - v[2] = v[8], v[3] = v[9]; - v[8] = v[4], v[9] = v[5]; - v[4] = tmp0, v[5] = tmp1; - tmp0 = v[6], tmp1 = v[7]; - v[6] = v[10], v[7] = v[11]; - v[10] = v[12], v[11] = v[13]; - v[12] = tmp0, v[13] = tmp1; -#elif defined(USE_AVX2) // _mm256_packs_epi16 ordering + static constexpr void order_packus([[maybe_unused]] uint64_t* v) { +#if defined(USE_AVX512) // _mm512_packus_epi16 ordering + std::swap(v[12], v[6]); + std::swap(v[13], v[7]); + + std::swap(v[6], v[10]); + std::swap(v[7], v[11]); + + std::swap(v[4], v[2]); + std::swap(v[5], v[3]); + + std::swap(v[2], v[8]); + std::swap(v[3], v[9]); +#elif defined(USE_AVX2) // _mm256_packus_epi16 ordering std::swap(v[2], v[4]); std::swap(v[3], v[5]); #endif } - static constexpr void inverse_order_packs([[maybe_unused]] uint64_t* v) { -#if defined(USE_AVX512) // Inverse _mm512_packs_epi16 ordering - uint64_t tmp0 = v[2], tmp1 = v[3]; - v[2] = v[4], v[3] = v[5]; - v[4] = v[8], v[5] = v[9]; - v[8] = tmp0, v[9] = tmp1; - tmp0 = v[6], tmp1 = v[7]; - v[6] = v[12], v[7] = v[13]; - v[12] = v[10], v[13] = v[11]; - v[10] = tmp0, v[11] = tmp1; -#elif defined(USE_AVX2) // Inverse _mm256_packs_epi16 ordering + static constexpr void inverse_order_packus([[maybe_unused]] uint64_t* v) 
{ +#if defined(USE_AVX512) // Inverse _mm512_packus_epi16 ordering + + // Here, our goal is to concatenate two 512-bit + // vectors, without changing their order. To do + // this, we reorder the weights every 1024 bits + // by swapping the elements by 64-bit blocks. + + // Current _mm512_packus_epi16 order: + // 01 23 45 67 // Vector 0 + // 89 AB CD EF // Vector 1 + // 01 89 23 AB 45 CD 67 EF // Packed Result + + std::swap(v[2], v[8]); + std::swap(v[3], v[9]); + + // Current _mm512_packus_epi16 order: + // 01 89 45 67 // Vector 0 + // 23 AB CD EF // Vector 1 + // 01 23 89 AB 45 CD 67 EF // Packed Result + + std::swap(v[4], v[2]); + std::swap(v[5], v[3]); + + // Current _mm512_packus_epi16 order: + // 01 45 89 67 // Vector 0 + // 23 AB CD EF // Vector 1 + // 01 23 45 AB 89 CD 67 EF // Packed Result + + std::swap(v[6], v[10]); + std::swap(v[7], v[11]); + + // Current _mm512_packus_epi16 order: + // 01 45 89 AB // Vector 0 + // 23 67 CD EF // Vector 1 + // 01 23 45 67 89 CD AB EF // Packed Result + + std::swap(v[12], v[6]); // Before this swap, v[6] holds the original v[10] + std::swap(v[13], v[7]); // Before this swap, v[7] holds the original v[11] + + // clang-format off + + // Current _mm512_packus_epi16 order: + // 01 45 89 CD // Vector 0 + // 23 67 AB EF // Vector 1 + // 01 23 45 67 89 AB CD EF // Packed Result + + // clang-format on +#elif defined(USE_AVX2) // Inverse _mm256_packus_epi16 ordering std::swap(v[2], v[4]); std::swap(v[3], v[5]); #endif @@ -300,7 +341,7 @@ class FeatureTransformer { read_leb_128(stream, weights, HalfDimensions * InputDimensions); read_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); - permute_weights(inverse_order_packs); + permute_weights(inverse_order_packus); scale_weights(true); return !stream.fail(); } @@ -308,14 +349,14 @@ class FeatureTransformer { // Write network parameters bool write_parameters(std::ostream& stream) { - permute_weights(order_packs); + permute_weights(order_packus); scale_weights(false); write_leb_128(stream, biases, 
HalfDimensions); write_leb_128(stream, weights, HalfDimensions * InputDimensions); write_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); - permute_weights(inverse_order_packs); + permute_weights(inverse_order_packus); scale_weights(true); return !stream.fail(); }