From 8c6d3aff30888d5af311d0b979fc7a4ff720f61b Mon Sep 17 00:00:00 2001 From: Shawn Xu Date: Fri, 27 Dec 2024 18:33:27 -0800 Subject: [PATCH] remove use of const cast in nnue_feature_transformer.h no functional change --- src/nnue/nnue_feature_transformer.h | 35 ++++++++++++++++------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index fa180678d89..83646632889 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -256,21 +256,22 @@ class FeatureTransformer { #endif } - void permute_weights([[maybe_unused]] void (*order_fn)(uint64_t*)) const { + static void permute_weights([[maybe_unused]] WeightType* weights, + [[maybe_unused]] BiasType* biases, + [[maybe_unused]] void (*order_fn)(uint64_t*)) { #if defined(USE_AVX2) #if defined(USE_AVX512) constexpr IndexType di = 16; #else constexpr IndexType di = 8; #endif - uint64_t* b = reinterpret_cast(const_cast(&biases[0])); + uint64_t* b = reinterpret_cast(&biases[0]); for (IndexType i = 0; i < HalfDimensions * sizeof(BiasType) / sizeof(uint64_t); i += di) order_fn(&b[i]); for (IndexType j = 0; j < InputDimensions; ++j) { - uint64_t* w = - reinterpret_cast(const_cast(&weights[j * HalfDimensions])); + uint64_t* w = reinterpret_cast(&weights[j * HalfDimensions]); for (IndexType i = 0; i < HalfDimensions * sizeof(WeightType) / sizeof(uint64_t); i += di) order_fn(&w[i]); @@ -278,17 +279,16 @@ class FeatureTransformer { #endif } - inline void scale_weights(bool read) const { + static void scale_weights(WeightType* weights, BiasType* biases, bool read) { for (IndexType j = 0; j < InputDimensions; ++j) { - WeightType* w = const_cast(&weights[j * HalfDimensions]); + WeightType* w = &weights[j * HalfDimensions]; for (IndexType i = 0; i < HalfDimensions; ++i) w[i] = read ? w[i] * 2 : w[i] / 2; } - BiasType* b = const_cast(biases); for (IndexType i = 0; i < HalfDimensions; ++i) - b[i] = read ? b[i] * 2 : b[i] / 2; + biases[i] = read ? biases[i] * 2 : biases[i] / 2; } // Read network parameters @@ -298,23 +298,26 @@ class FeatureTransformer { read_leb_128(stream, weights, HalfDimensions * InputDimensions); read_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); - permute_weights(inverse_order_packs); - scale_weights(true); + permute_weights(weights, biases, inverse_order_packs); + scale_weights(weights, biases, true); return !stream.fail(); } // Write network parameters bool write_parameters(std::ostream& stream) const { + BiasType* biasesToWrite = new BiasType[HalfDimensions]; + WeightType* weightsToWrite = new WeightType[HalfDimensions * InputDimensions]; - permute_weights(order_packs); - scale_weights(false); + std::memcpy(biasesToWrite, biases, sizeof(BiasType) * HalfDimensions); + std::memcpy(weightsToWrite, weights, sizeof(WeightType) * HalfDimensions * InputDimensions); - write_leb_128(stream, biases, HalfDimensions); - write_leb_128(stream, weights, HalfDimensions * InputDimensions); + permute_weights(weightsToWrite, biasesToWrite, order_packs); + scale_weights(weightsToWrite, biasesToWrite, false); + + write_leb_128(stream, biasesToWrite, HalfDimensions); + write_leb_128(stream, weightsToWrite, HalfDimensions * InputDimensions); write_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); - permute_weights(inverse_order_packs); - scale_weights(true); return !stream.fail(); }