diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index fa180678d89..d36a173c538 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -256,21 +256,22 @@ class FeatureTransformer { #endif } - void permute_weights([[maybe_unused]] void (*order_fn)(uint64_t*)) const { + static void permute_weights([[maybe_unused]] WeightType* weights, + [[maybe_unused]] BiasType* biases, + [[maybe_unused]] void (*order_fn)(uint64_t*)) { #if defined(USE_AVX2) #if defined(USE_AVX512) constexpr IndexType di = 16; #else constexpr IndexType di = 8; #endif - uint64_t* b = reinterpret_cast(const_cast(&biases[0])); + uint64_t* b = reinterpret_cast(&biases[0]); for (IndexType i = 0; i < HalfDimensions * sizeof(BiasType) / sizeof(uint64_t); i += di) order_fn(&b[i]); for (IndexType j = 0; j < InputDimensions; ++j) { - uint64_t* w = - reinterpret_cast(const_cast(&weights[j * HalfDimensions])); + uint64_t* w = reinterpret_cast(&weights[j * HalfDimensions]); for (IndexType i = 0; i < HalfDimensions * sizeof(WeightType) / sizeof(uint64_t); i += di) order_fn(&w[i]); @@ -278,17 +279,16 @@ class FeatureTransformer { #endif } - inline void scale_weights(bool read) const { + static void scale_weights(WeightType* weights, BiasType* biases, bool read) { for (IndexType j = 0; j < InputDimensions; ++j) { - WeightType* w = const_cast(&weights[j * HalfDimensions]); + WeightType* w = &weights[j * HalfDimensions]; for (IndexType i = 0; i < HalfDimensions; ++i) w[i] = read ? w[i] * 2 : w[i] / 2; } - BiasType* b = const_cast(biases); for (IndexType i = 0; i < HalfDimensions; ++i) - b[i] = read ? b[i] * 2 : b[i] / 2; + biases[i] = read ? biases[i] * 2 : biases[i] / 2; } // Read network parameters @@ -298,23 +298,28 @@ class FeatureTransformer { read_leb_128(stream, weights, HalfDimensions * InputDimensions); read_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); - permute_weights(inverse_order_packs); - scale_weights(true); + permute_weights(weights, biases, inverse_order_packs); + scale_weights(weights, biases, true); return !stream.fail(); } // Write network parameters bool write_parameters(std::ostream& stream) const { + std::unique_ptr biasesToWrite{new BiasType[HalfDimensions]}; + std::unique_ptr weightsToWrite{ + new WeightType[HalfDimensions * InputDimensions]}; - permute_weights(order_packs); - scale_weights(false); + std::memcpy(biasesToWrite.get(), biases, sizeof(BiasType) * HalfDimensions); + std::memcpy(weightsToWrite.get(), weights, + sizeof(WeightType) * HalfDimensions * InputDimensions); - write_leb_128(stream, biases, HalfDimensions); - write_leb_128(stream, weights, HalfDimensions * InputDimensions); + permute_weights(weightsToWrite.get(), biasesToWrite.get(), order_packs); + scale_weights(weightsToWrite.get(), biasesToWrite.get(), false); + + write_leb_128(stream, biasesToWrite.get(), HalfDimensions); + write_leb_128(stream, weightsToWrite.get(), HalfDimensions * InputDimensions); write_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); - permute_weights(inverse_order_packs); - scale_weights(true); return !stream.fail(); }