Commit d4490e5

Enable AVX2 and AVX-512 in AffineTransform even when kOutputDimensions is not divisible by 4.

KazApps committed Dec 4, 2024
1 parent f25d633 commit d4490e5
Showing 2 changed files with 97 additions and 19 deletions.
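
When kOutputDimensions is a multiple of 4, this layer scrambles its weights at load time and runs the fast SSSE3/NEON-dotprod kernels; otherwise it falls back to the dense routine renamed here to affine_transform_unaligned, which previously used only SSE2 or NEON vectors. This commit adds AVX2 and AVX-512 paths to that fallback. For orientation, here is a minimal scalar sketch of what the fallback computes (not part of the commit; unsigned stands in for the engine's IndexType):

#include <cstdint>

// Each output is its bias plus the dot product of one int8 weight row with
// the uint8 input vector; rows are padded to kPaddedInputDimensions bytes.
template<unsigned kInputDimensions, unsigned kPaddedInputDimensions, unsigned kOutputDimensions>
void affine_transform_reference(std::int32_t* output,
                                const std::int8_t* weights,
                                const std::int32_t* biases,
                                const std::uint8_t* input) {
    for (unsigned i = 0; i < kOutputDimensions; ++i) {
        std::int32_t sum = biases[i];
        for (unsigned j = 0; j < kInputDimensions; ++j)
            sum += std::int32_t(weights[i * kPaddedInputDimensions + j]) * input[j];
        output[i] = sum;
    }
}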
104 changes: 91 additions & 13 deletions source/eval/nnue/layers/affine_transform.h
@@ -14,30 +14,105 @@
namespace Eval::NNUE::Layers {

template<IndexType kInputDimensions, IndexType kPaddedInputDimensions, IndexType kOutputDimensions>
-static void affine_transform_non_ssse3(std::int32_t* output,
+static void affine_transform_unaligned(std::int32_t* output,
const std::int8_t* weights,
const std::int32_t* biases,
const std::uint8_t* input) {
-#if defined(USE_SSSE3) || defined(USE_NEON_DOTPROD)
-#if defined(USE_SSE2)
+#if defined(USE_SSE2) || defined(USE_NEON)

#if defined(USE_AVX512)
constexpr IndexType kNumChunks = CeilToMultiple<IndexType>(kInputDimensions, 64) / 64;
const __m512i kZeros = _mm512_setzero_si512();
const auto inputVector = reinterpret_cast<const __m512i*>(input);
#elif defined(USE_AVX2)
constexpr IndexType kNumChunks = CeilToMultiple<IndexType>(kInputDimensions, 32) / 32;
const __m256i kZeros = _mm256_setzero_si256();
const auto inputVector = reinterpret_cast<const __m256i*>(input);
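// CeilToMultiple rounds the chunk count up, so a kInputDimensions that is
// not a multiple of the vector width is covered by reading into the zeroed
// padding implied by kPaddedInputDimensions.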
#elif defined(USE_SSE2)
// At least a multiple of 16, with SSE2.
constexpr IndexType kNumChunks = CeilToMultiple<IndexType>(kInputDimensions, 16) / 16;
const __m128i kZeros = _mm_setzero_si128();
const auto inputVector = reinterpret_cast<const __m128i*>(input);

#elif defined(USE_NEON)
constexpr IndexType kNumChunks = CeilToMultiple<IndexType>(kInputDimensions, 16) / 16;
const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
#endif

for (IndexType i = 0; i < kOutputDimensions; ++i)
{
const IndexType offset = i * kPaddedInputDimensions;

-#if defined(USE_SSE2)
+#if defined(USE_AVX512)

__m512i sumLo = _mm512_castsi128_si512(_mm_cvtsi32_si128(biases[i]));
__m512i sumHi = _mm512_setzero_si512();
const auto row = reinterpret_cast<const __m512i*>(&weights[offset]);
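// Unpacking the int8 row with itself and shifting right arithmetically by 8
// sign-extends each byte to int16; unpacking the uint8 input against zero
// zero-extends it. _mm512_madd_epi16 then multiplies adjacent int16 pairs
// and sums each pair into an int32 lane.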

for (IndexType j = 0; j < kNumChunks; ++j)
{
__m512i row_j = _mm512_load_si512(&row[j]);
__m512i input_j = _mm512_load_si512(reinterpret_cast<const __m512i*>(&inputVector[j]));
__m512i extendedRowLo = _mm512_srai_epi16(_mm512_unpacklo_epi8(row_j, row_j), 8);
__m512i extendedRowHi = _mm512_srai_epi16(_mm512_unpackhi_epi8(row_j, row_j), 8);
__m512i extendedInputLo = _mm512_unpacklo_epi8(input_j, _mm512_setzero_si512());
__m512i extendedInputHi = _mm512_unpackhi_epi8(input_j, _mm512_setzero_si512());
__m512i productLo = _mm512_madd_epi16(extendedRowLo, extendedInputLo);
__m512i productHi = _mm512_madd_epi16(extendedRowHi, extendedInputHi);
sumLo = _mm512_add_epi32(sumLo, productLo);
sumHi = _mm512_add_epi32(sumHi, productHi);
}

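// Horizontal reduction: fold the 512-bit accumulator to 256 and then 128
// bits, then swap-and-add the remaining four int32 lanes down to one scalar.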
__m512i sum = _mm512_add_epi32(sumLo, sumHi);
__m256i sumLow256 = _mm512_castsi512_si256(sum);
__m256i sumHigh256 = _mm512_extracti64x4_epi64(sum, 1);
__m256i finalSum256 = _mm256_add_epi32(sumLow256, sumHigh256);

__m128i sumLow128 = _mm256_castsi256_si128(finalSum256);
__m128i sumHigh128 = _mm256_extracti128_si256(finalSum256, 1);
__m128i finalSum = _mm_add_epi32(sumLow128, sumHigh128);

__m128i sumHigh_64 = _mm_shuffle_epi32(finalSum, _MM_SHUFFLE(1, 0, 3, 2));
finalSum = _mm_add_epi32(finalSum, sumHigh_64);
__m128i sum_second_32 = _mm_shufflelo_epi16(finalSum, _MM_SHUFFLE(1, 0, 3, 2));
finalSum = _mm_add_epi32(finalSum, sum_second_32);
output[i] = _mm_cvtsi128_si32(finalSum);

#elif defined(USE_AVX2)

__m256i sumLo = _mm256_castsi128_si256(_mm_cvtsi32_si128(biases[i]));
__m256i sumHi = _mm256_setzero_si256();
const auto row = reinterpret_cast<const __m256i*>(&weights[offset]);
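// Same sign-extend/zero-extend and madd scheme as the AVX-512 branch above,
// applied to 256-bit vectors.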

for (IndexType j = 0; j < kNumChunks; ++j)
{
__m256i row_j = _mm256_load_si256(&row[j]);
__m256i input_j = _mm256_load_si256(reinterpret_cast<const __m256i*>(&inputVector[j]));
__m256i extendedRowLo = _mm256_srai_epi16(_mm256_unpacklo_epi8(row_j, row_j), 8);
__m256i extendedRowHi = _mm256_srai_epi16(_mm256_unpackhi_epi8(row_j, row_j), 8);
__m256i extendedInputLo = _mm256_unpacklo_epi8(input_j, _mm256_setzero_si256());
__m256i extendedInputHi = _mm256_unpackhi_epi8(input_j, _mm256_setzero_si256());
__m256i productLo = _mm256_madd_epi16(extendedRowLo, extendedInputLo);
__m256i productHi = _mm256_madd_epi16(extendedRowHi, extendedInputHi);
sumLo = _mm256_add_epi32(sumLo, productLo);
sumHi = _mm256_add_epi32(sumHi, productHi);
}

__m256i sum = _mm256_add_epi32(sumLo, sumHi);
__m128i sumLow128 = _mm256_castsi256_si128(sum);
__m128i sumHigh128 = _mm256_extracti128_si256(sum, 1);
__m128i finalSum = _mm_add_epi32(sumLow128, sumHigh128);
__m128i sumHigh_64 = _mm_shuffle_epi32(finalSum, _MM_SHUFFLE(1, 0, 3, 2));
finalSum = _mm_add_epi32(finalSum, sumHigh_64);
__m128i sum_second_32 = _mm_shufflelo_epi16(finalSum, _MM_SHUFFLE(1, 0, 3, 2));
finalSum = _mm_add_epi32(finalSum, sum_second_32);
output[i] = _mm_cvtsi128_si32(finalSum);

#elif defined(USE_SSE2)

__m128i sumLo = _mm_cvtsi32_si128(biases[i]);
__m128i sumHi = kZeros;
const auto row = reinterpret_cast<const __m128i*>(&weights[offset]);

for (IndexType j = 0; j < kNumChunks; ++j)
{
__m128i row_j = _mm_load_si128(&row[j]);
@@ -51,6 +126,7 @@ static void affine_transform_non_ssse3(std::int32_t* output,
sumLo = _mm_add_epi32(sumLo, productLo);
sumHi = _mm_add_epi32(sumHi, productHi);
}

__m128i sum = _mm_add_epi32(sumLo, sumHi);
__m128i sumHigh_64 = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2));
sum = _mm_add_epi32(sum, sumHigh_64);
@@ -62,12 +138,14 @@ static void affine_transform_non_ssse3(std::int32_t* output,

int32x4_t sum = {biases[i]};
const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
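// vmull_s8 widens an 8-byte multiply to int16x8, vmlal_s8 accumulates a
// second one, and vpadalq_s16 pairwise-adds the products into the int32x4
// accumulator, whose lane 0 was seeded with biases[i].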

for (IndexType j = 0; j < kNumChunks; ++j)
{
int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]);
product = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]);
sum = vpadalq_s16(sum, product);
}

output[i] = sum[0] + sum[1] + sum[2] + sum[3];

#endif
@@ -129,14 +207,14 @@ class AffineTransform {
PreviousLayer::GetStructureString() + ")";
}

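// For a row-major index i = o * kPaddedInputDimensions + j (output row o,
// input column j), the scrambled index is
// (j / 4) * kOutputDimensions * 4 + o * 4 + j % 4: each group of four input
// columns stores its weights for every output contiguously, the layout the
// vectorized SSSE3/dotprod kernels consume.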
-static constexpr IndexType get_weight_index_scrambled(IndexType i) {
+static constexpr IndexType GetWeightIndexScrambled(IndexType i) {
return (i / 4) % (kPaddedInputDimensions / 4) * kOutputDimensions * 4
+ i / kPaddedInputDimensions * 4 + i % 4;
}

-static constexpr IndexType get_weight_index(IndexType i) {
+static constexpr IndexType GetWeightIndex(IndexType i) {
#if defined(USE_SSSE3) || defined(USE_NEON_DOTPROD)
-return kOutputDimensions % 4 == 0 ? get_weight_index_scrambled(i) : i;
+return kOutputDimensions % 4 == 0 ? GetWeightIndexScrambled(i) : i;
#else
return i;
#endif
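// When kOutputDimensions is not a multiple of 4, the weights stay in
// row-major order and Propagate() falls back to affine_transform_unaligned,
// the routine this commit extends with AVX2 and AVX-512 paths.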
@@ -150,7 +228,7 @@ class AffineTransform {
for (std::size_t i = 0; i < kOutputDimensions; ++i)
biases_[i] = read_little_endian<BiasType>(stream);
for (std::size_t i = 0; i < kOutputDimensions * kPaddedInputDimensions; ++i)
-weights_[get_weight_index(IndexType(i))] = read_little_endian<WeightType>(stream);
+weights_[GetWeightIndex(IndexType(i))] = read_little_endian<WeightType>(stream);
return !stream.fail() ? Tools::ResultCode::Ok : Tools::ResultCode::FileReadError;
}

@@ -310,7 +388,7 @@ class AffineTransform {
else
#endif

-affine_transform_non_ssse3<kInputDimensions, kPaddedInputDimensions, kOutputDimensions>(
+affine_transform_unaligned<kInputDimensions, kPaddedInputDimensions, kOutputDimensions>(
output, weights_, biases_, input);
}
else if constexpr (kOutputDimensions == 1)
12 changes: 6 additions & 6 deletions source/eval/nnue/layers/affine_transform_sparse_input.h
@@ -157,14 +157,14 @@ class AffineTransformSparseInput {
PreviousLayer::GetStructureString() + ")";
}

-static constexpr IndexType get_weight_index_scrambled(IndexType i) {
+static constexpr IndexType GetWeightIndexScrambled(IndexType i) {
return (i / kChunkSize) % (kPaddedInputDimensions / kChunkSize) * kOutputDimensions * kChunkSize
+ i / kPaddedInputDimensions * kChunkSize + i % kChunkSize;
}
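// Same permutation as the dense layer's, generalized from groups of 4 bytes
// to groups of kChunkSize bytes.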

-static constexpr IndexType get_weight_index(IndexType i) {
+static constexpr IndexType GetWeightIndex(IndexType i) {
#if defined(USE_SSSE3) || USE_NEON >= 8
-return kOutputDimensions % 4 == 0 ? get_weight_index_scrambled(i) : i;
+return kOutputDimensions % 4 == 0 ? GetWeightIndexScrambled(i) : i;
#else
return i;
#endif
@@ -178,7 +178,7 @@ class AffineTransformSparseInput {
for (std::size_t i = 0; i < kOutputDimensions; ++i)
biases_[i] = read_little_endian<BiasType>(stream);
for (std::size_t i = 0; i < kOutputDimensions * kPaddedInputDimensions; ++i)
-weights_[get_weight_index(IndexType(i))] = read_little_endian<WeightType>(stream);
+weights_[GetWeightIndex(IndexType(i))] = read_little_endian<WeightType>(stream);
return !stream.fail() ? Tools::ResultCode::Ok : Tools::ResultCode::FileReadError;
}

@@ -363,15 +363,15 @@ class AffineTransformSparseInput {
}
else
#endif
-affine_transform_non_ssse3<kInputDimensions, kPaddedInputDimensions, kOutputDimensions>(
+affine_transform_unaligned<kInputDimensions, kPaddedInputDimensions, kOutputDimensions>(
output, weights_, biases_, input);

#undef vec_set_32
#undef vec_add_dpbusd_32

#else
// Use dense implementation for the other architectures.
-affine_transform_non_ssse3<kInputDimensions, kPaddedInputDimensions, kOutputDimensions>(
+affine_transform_unaligned<kInputDimensions, kPaddedInputDimensions, kOutputDimensions>(
output, weights_, biases_, input);
#endif

