From f35ef2068130beeba79c83fa16d96dfbd7e07397 Mon Sep 17 00:00:00 2001 From: KazApps Date: Fri, 13 Sep 2024 10:56:51 +0900 Subject: [PATCH] =?UTF-8?q?AVX-512=E3=81=A7=E3=82=AF=E3=83=A9=E3=83=83?= =?UTF-8?q?=E3=82=B7=E3=83=A5=E3=81=99=E3=82=8B=E5=95=8F=E9=A1=8C=E3=82=92?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- source/eval/nnue/nnue_accumulator.h | 3 ++- source/eval/nnue/nnue_feature_transformer.h | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/source/eval/nnue/nnue_accumulator.h b/source/eval/nnue/nnue_accumulator.h index c7c43a3e5..f6ed4194f 100644 --- a/source/eval/nnue/nnue_accumulator.h +++ b/source/eval/nnue/nnue_accumulator.h @@ -15,7 +15,8 @@ namespace NNUE { // 入力特徴量をアフィン変換した結果を保持するクラス // 最終的な出力である評価値も一緒に持たせておく -struct alignas(32) Accumulator { +// AVX-512命令を使用する場合に64bytesのアライメントが要求される。 +struct alignas(64) Accumulator { std::int16_t accumulation[2][kRefreshTriggers.size()][kTransformedFeatureDimensions]; Value score = VALUE_ZERO; diff --git a/source/eval/nnue/nnue_feature_transformer.h b/source/eval/nnue/nnue_feature_transformer.h index 8a722b34c..5b8dd8651 100644 --- a/source/eval/nnue/nnue_feature_transformer.h +++ b/source/eval/nnue/nnue_feature_transformer.h @@ -289,7 +289,11 @@ class FeatureTransformer { const IndexType offset = kHalfDimensions * index; auto accumulation = reinterpret_cast(&accumulator.accumulation[perspective][i][0]); auto column = reinterpret_cast(&weights_[offset]); +#if defined(USE_AVX512) + constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth; +#else constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2); +#endif for (IndexType j = 0; j < kNumChunks; ++j) { accumulation[j] = vec_add_16(accumulation[j], column[j]); } @@ -327,7 +331,11 @@ class FeatureTransformer { RawFeatures::AppendChangedIndices(pos, kRefreshTriggers[i], removed_indices, added_indices, reset); for (Color perspective : {BLACK, WHITE}) { #if defined(VECTOR) +#if defined(USE_AVX512) + constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth; +#else constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2); +#endif auto accumulation = reinterpret_cast(&accumulator.accumulation[perspective][i][0]); #endif if (reset[perspective]) {