Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
no functional change
  • Loading branch information
xu-shawn committed Jan 15, 2025
1 parent c085670 commit 62eec0b
Showing 1 changed file with 65 additions and 54 deletions.
119 changes: 65 additions & 54 deletions src/nnue/nnue_feature_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <type_traits>
#include <utility>

#include "../position.h"
Expand Down Expand Up @@ -124,8 +125,7 @@ using psqt_vec_t = int32x4_t;
#define vec_add_16(a, b) vaddq_s16(a, b)
#define vec_sub_16(a, b) vsubq_s16(a, b)
#define vec_mulhi_16(a, b) vqdmulhq_s16(a, b)
#define vec_zero() \
vec_t { 0 }
#define vec_zero() vec_t{0}
#define vec_set_16(a) vdupq_n_s16(a)
#define vec_max_16(a, b) vmaxq_s16(a, b)
#define vec_min_16(a, b) vminq_s16(a, b)
Expand All @@ -135,8 +135,7 @@ using psqt_vec_t = int32x4_t;
#define vec_store_psqt(a, b) *(a) = (b)
#define vec_add_psqt_32(a, b) vaddq_s32(a, b)
#define vec_sub_psqt_32(a, b) vsubq_s32(a, b)
#define vec_zero_psqt() \
psqt_vec_t { 0 }
#define vec_zero_psqt() psqt_vec_t{0}
#define NumRegistersSIMD 16
#define MaxChunkSize 16

Expand All @@ -146,6 +145,58 @@ using psqt_vec_t = int32x4_t;
#endif


struct Packing {
// Returns the order by which 128-bit blocks of a 1024-bit data must
// be permuted so that calling packus on adjacent vectors of 16-bit
// integers loaded from the data results in the pre-permutation order
static constexpr auto packus_epi16_order = []() -> std::array<std::size_t, 8> {
#if defined(USE_AVX512)
// _mm512_packus_epi16 after permutation:
// 0 2 4 6 // Vector 0
// 1 3 5 7 // Vector 1
// 0 1 2 3 4 5 6 7 // Packed Result
return {0, 2, 4, 6, 1, 3, 5, 7};
#elif defined(USE_AVX2)
// _mm256_packus_epi16 after permutation:
// 0 2 // Vector 0 // 4 6 // Vector 2
// 1 3 // Vector 1 // 5 7 // Vector 3
// 0 1 2 3 // Packed Result // 4 5 6 7 // Packed Result
return {0, 2, 1, 3, 4, 6, 5, 7};
#else
return {0, 1, 2, 3, 4, 5, 6, 7};
#endif
}();

static constexpr std::size_t epi16_block_size = 8;
static constexpr std::size_t process_chunk = epi16_block_size * packus_epi16_order.size();

static constexpr auto permute_for_packus_epi16 = [](auto* const v) {
std::array<std::remove_pointer_t<decltype(v)>, epi16_block_size * packus_epi16_order.size()>
buffer;

for (std::size_t i = 0; i < packus_epi16_order.size(); i++)
for (std::size_t j = 0; j < epi16_block_size; j++)
buffer[i * epi16_block_size + j] = v[packus_epi16_order[i] * epi16_block_size + j];

for (std::size_t i = 0; i < buffer.size(); i++)
v[i] = buffer[i];
};


static constexpr auto unpermute_for_packus_epi16 = [](auto* const v) {
std::array<std::remove_pointer_t<decltype(v)>, epi16_block_size * packus_epi16_order.size()>
buffer;

for (std::size_t i = 0; i < packus_epi16_order.size(); i++)
for (std::size_t j = 0; j < epi16_block_size; j++)
buffer[packus_epi16_order[i] * epi16_block_size + j] = v[i * epi16_block_size + j];

for (std::size_t i = 0; i < buffer.size(); i++)
v[i] = buffer[i];
};
};


// Compute optimal SIMD register count for feature transformer accumulation.
template<IndexType TransformedFeatureWidth, IndexType HalfDimensions>
class SIMDTiling {
Expand Down Expand Up @@ -203,7 +254,7 @@ class SIMDTiling {

// Input feature converter
template<IndexType TransformedFeatureDimensions,
Accumulator<TransformedFeatureDimensions> StateInfo::*accPtr>
Accumulator<TransformedFeatureDimensions> StateInfo::* accPtr>
class FeatureTransformer {

// Number of output dimensions for one side
Expand All @@ -228,57 +279,17 @@ class FeatureTransformer {
return FeatureSet::HashValue ^ (OutputDimensions * 2);
}

static constexpr void order_packs([[maybe_unused]] uint64_t* v) {
#if defined(USE_AVX512) // _mm512_packs_epi16 ordering
uint64_t tmp0 = v[2], tmp1 = v[3];
v[2] = v[8], v[3] = v[9];
v[8] = v[4], v[9] = v[5];
v[4] = tmp0, v[5] = tmp1;
tmp0 = v[6], tmp1 = v[7];
v[6] = v[10], v[7] = v[11];
v[10] = v[12], v[11] = v[13];
v[12] = tmp0, v[13] = tmp1;
#elif defined(USE_AVX2) // _mm256_packs_epi16 ordering
std::swap(v[2], v[4]);
std::swap(v[3], v[5]);
#endif
}

static constexpr void inverse_order_packs([[maybe_unused]] uint64_t* v) {
#if defined(USE_AVX512) // Inverse _mm512_packs_epi16 ordering
uint64_t tmp0 = v[2], tmp1 = v[3];
v[2] = v[4], v[3] = v[5];
v[4] = v[8], v[5] = v[9];
v[8] = tmp0, v[9] = tmp1;
tmp0 = v[6], tmp1 = v[7];
v[6] = v[12], v[7] = v[13];
v[12] = v[10], v[13] = v[11];
v[10] = tmp0, v[11] = tmp1;
#elif defined(USE_AVX2) // Inverse _mm256_packs_epi16 ordering
std::swap(v[2], v[4]);
std::swap(v[3], v[5]);
#endif
}

void permute_weights([[maybe_unused]] void (*order_fn)(uint64_t*)) {
#if defined(USE_AVX2)
#if defined(USE_AVX512)
constexpr IndexType di = 16;
#else
constexpr IndexType di = 8;
#endif
uint64_t* b = reinterpret_cast<uint64_t*>(&biases[0]);
for (IndexType i = 0; i < HalfDimensions * sizeof(BiasType) / sizeof(uint64_t); i += di)
order_fn(&b[i]);
template<typename Function>
void permute_weights(Function order_fn) {
for (IndexType i = 0; i < HalfDimensions; i += Packing::process_chunk)
order_fn(&biases[i]);

for (IndexType j = 0; j < InputDimensions; ++j)
{
uint64_t* w = reinterpret_cast<uint64_t*>(&weights[j * HalfDimensions]);
for (IndexType i = 0; i < HalfDimensions * sizeof(WeightType) / sizeof(uint64_t);
i += di)
auto* w = &weights[j * HalfDimensions];
for (IndexType i = 0; i < HalfDimensions; i += Packing::process_chunk)
order_fn(&w[i]);
}
#endif
}

inline void scale_weights(bool read) {
Expand All @@ -300,22 +311,22 @@ class FeatureTransformer {
read_leb_128<WeightType>(stream, weights, HalfDimensions * InputDimensions);
read_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);

permute_weights(inverse_order_packs);
permute_weights(Packing::permute_for_packus_epi16);
scale_weights(true);
return !stream.fail();
}

// Write network parameters
bool write_parameters(std::ostream& stream) {

permute_weights(order_packs);
permute_weights(Packing::unpermute_for_packus_epi16);
scale_weights(false);

write_leb_128<BiasType>(stream, biases, HalfDimensions);
write_leb_128<WeightType>(stream, weights, HalfDimensions * InputDimensions);
write_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);

permute_weights(inverse_order_packs);
permute_weights(Packing::permute_for_packus_epi16);
scale_weights(true);
return !stream.fail();
}
Expand Down

0 comments on commit 62eec0b

Please sign in to comment.