Skip to content

Commit

Permalink
Cleanup for NNUE arch selection. All three supported archs are now us…
Browse files Browse the repository at this point in the history
…able in same binary. (#380)

Switch to new magic number for V5-512.
  • Loading branch information
Matthies authored Nov 20, 2022
1 parent 7d61820 commit 25bf728
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 72 deletions.
7 changes: 4 additions & 3 deletions src/RubiChess.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#pragma once

#define VERNUMLEGACY 2022
#define NNUEDEFAULT nn-a2bb5af869-20221001.nnue
#define NNUEDEFAULT nn-5e6b321f90-20221001.nnue

// enable this switch for faster SSE2 code using 16bit integers
#define FASTSSE2
Expand Down Expand Up @@ -692,7 +692,8 @@ enum NnueType { NnueDisabled = 0, NnueArchV1, NnueArchV5 };
// The following constants were introduced in original NNUE port from Shogi
#define NNUEFILEVERSIONROTATE 0x7AF32F16u
#define NNUEFILEVERSIONNOBPZ 0x7AF32F17u
#define NNUEFILEVERSIONSFNNv5 0x7af32f20u
#define NNUEFILEVERSIONSFNNv5_1024 0x7af32f20u
#define NNUEFILEVERSIONSFNNv5_512 0x7af32f30u
#define NNUENETLAYERHASH 0xCC03DAE4u
#define NNUECLIPPEDRELUHASH 0x538D24C7u
#define NNUEFEATUREHASH_HalfKP 0x5D69D5B8u
Expand Down Expand Up @@ -1656,7 +1657,7 @@ class chessposition
template <NnueType Nt, Color c> void HalfkpAppendChangedIndices(DirtyPiece* dp, NnueIndexList *add, NnueIndexList *remove);
template <NnueType Nt, Color c, unsigned int NnueFtHalfdims, unsigned int NnuePsqtBuckets> void UpdateAccumulator();
template <NnueType Nt, unsigned int NnueFtHalfdims, unsigned int NnuePsqtBuckets> int Transform(clipped_t *output, int bucket = 0);
template <NnueType Nt> int NnueGetEval();
int NnueGetEval();
#ifdef NNUELEARN
void toSfen(PackedSfen *sfen);
int getFromSfen(PackedSfen* sfen);
Expand Down
5 changes: 1 addition & 4 deletions src/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -870,10 +870,7 @@ int chessposition::getEval()
if (NnueReady && abs(GETEGVAL(psqval)) < NnuePsqThreshold)
{
int frcCorrection = (en.chess960 ? getFrcCorrection() : 0);
if (NnueReady == NnueArchV1)
score = NnueGetEval<NnueArchV1>();
else
score = NnueGetEval<NnueArchV5>();
score = NnueGetEval();
score += S2MSIGN(state & S2MMASK) * contempt;
int phscaled = score * (116 + phcount) / 128;

Expand Down
156 changes: 91 additions & 65 deletions src/nnue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,22 @@ class NnueArchitecture
virtual bool ReadWeights(NnueNetsource* nr, uint32_t nethash) = 0;
virtual void WriteFeatureWeights(NnueNetsource* nr, bool bpz) = 0;
virtual void WriteWeights(NnueNetsource* nr, uint32_t nethash) = 0;
virtual void rescaleLastLayer(int ratio64) = 0;
virtual string getArchDescription() = 0;
virtual void RescaleLastLayer(int ratio64) = 0;
virtual string GetArchName() = 0;
virtual string GetArchDescription() = 0;
virtual uint32_t GetFtHash() = 0;
virtual uint32_t GetHash() = 0;
virtual int getEval(chessposition* pos) = 0;
virtual int GetEval(chessposition* pos) = 0;
virtual int16_t* GetFeatureWeight() = 0;
virtual int16_t* GetFeatureBias() = 0;
virtual int32_t* GetFeaturePsqtWeight() = 0;
virtual uint32_t GetFileVersion() = 0;
};


NnueArchitecture* NnueCurrentArch;


// The network architecture V1
class NnueArchitectureV1 : public NnueArchitecture {
public:
Expand Down Expand Up @@ -138,15 +146,18 @@ class NnueArchitectureV1 : public NnueArchitecture {
nr->write((unsigned char*)&nethash, sizeof(uint32_t));
LayerStack[0].NnueOut.WriteWeights(nr);
}
void rescaleLastLayer(int ratio64) {
void RescaleLastLayer(int ratio64) {
LayerStack[0].NnueOut.bias[0] = (int32_t)round(LayerStack[0].NnueOut.bias[0] * ratio64 / NnueValueScale);
for (unsigned int i = 0; i < NnueHidden2Dims; i++)
LayerStack[0].NnueOut.weight[i] = (int32_t)round(LayerStack[0].NnueOut.weight[i] * ratio64 / NnueValueScale);
}
string getArchDescription() {
string GetArchName() {
return "V1";
}
string GetArchDescription() {
return "Features=HalfKP(Friend)[40960->256x2],Network=AffineTransform[1<-32](ClippedReLU[32](AffineTransform[32<-32](ClippedReLU[32](AffineTransform[32<-512](InputSlice[512(0:512)])))))";
}
int getEval(chessposition *pos) {
int GetEval(chessposition *pos) {
struct NnueNetwork {
alignas(64) clipped_t input[NnueFtOutputdims];
alignas(64) int32_t hidden1_values[NnueHidden1Dims];
Expand All @@ -165,12 +176,23 @@ class NnueArchitectureV1 : public NnueArchitecture {

return network.out_value * NnueValueScale / 1024;
}
} NnueV1;

int16_t* GetFeatureWeight() {
return NnueFt.weight;
}
int16_t* GetFeatureBias() {
return NnueFt.bias;
}
int32_t* GetFeaturePsqtWeight() {
return nullptr;
}
uint32_t GetFileVersion() {
return NNUEFILEVERSIONNOBPZ; // always write networks without BPZ
}
};

template <unsigned int NnueFtOutputdims>
class NnueArchitectureV5 : public NnueArchitecture {
public:
static constexpr unsigned int NnueFtOutputdims = 512;
static_assert(NnueFtOutputdims <= MAXINPUTLAYER, "Accumulator not big enough");
static constexpr unsigned int NnueFtHalfdims = NnueFtOutputdims;
static constexpr unsigned int NnueFtInputdims = 64 * 11 * 64 / 2;
Expand Down Expand Up @@ -227,17 +249,20 @@ class NnueArchitectureV5 : public NnueArchitecture {
LayerStack[i].NnueOut.WriteWeights(nr);
}
}
void rescaleLastLayer(int ratio64) {
void RescaleLastLayer(int ratio64) {
for (unsigned int b = 0; b < NnueLayerStacks; b++) {
LayerStack[b].NnueOut.bias[0] = (int32_t)round(LayerStack[b].NnueOut.bias[0] * ratio64 / NnueValueScale);
for (unsigned int i = 0; i < NnueHidden2Dims; i++)
LayerStack[b].NnueOut.weight[i] = (int32_t)round(LayerStack[b].NnueOut.weight[i] * ratio64 / NnueValueScale);
}
}
string getArchDescription() {
return "HalfKAv2_hm, 512x16+16x32x1";
string GetArchName() {
return "V5-" + to_string(NnueFtOutputdims);
}
string GetArchDescription() {
return "HalfKAv2_hm, " + to_string(NnueFtOutputdims) + "x16+16x32x1";
}
int getEval(chessposition* pos) {
int GetEval(chessposition* pos) {
struct NnueNetwork {
alignas(64) clipped_t input[NnueFtOutputdims];
alignas(64)int32_t hidden1_values[NnueHidden1Dims];
Expand All @@ -264,23 +289,20 @@ class NnueArchitectureV5 : public NnueArchitecture {

return (psqt + positional) * NnueValueScale / 1024;
}
} NnueV5;


template<NnueType Nt>
class NnueArchInterface {
public:
constexpr static int16_t* getFeatureWeight() {
return (Nt == NnueArchV1 ? NnueV1.NnueFt.weight : NnueV5.NnueFt.weight);
int16_t* GetFeatureWeight() {
return NnueFt.weight;
}
constexpr static int16_t* getFeatureBias() {
return (Nt == NnueArchV1 ? NnueV1.NnueFt.bias : NnueV5.NnueFt.bias);
int16_t* GetFeatureBias() {
return NnueFt.bias;
}
constexpr static int32_t* getFeaturePsqtWeight() {
return (Nt == NnueArchV1 ? nullptr : NnueV5.NnueFt.psqtWeights);
int32_t* GetFeaturePsqtWeight() {
return NnueFt.psqtWeights;
}
constexpr static int getEval(chessposition* pos) {
return (Nt == NnueArchV1 ? NnueV1.getEval(pos) : NnueV5.getEval(pos));
uint32_t GetFileVersion() {
if (NnueFtOutputdims == 512)
return NNUEFILEVERSIONSFNNv5_512;
if (NnueFtOutputdims == 1024)
return NNUEFILEVERSIONSFNNv5_1024;
}
};

Expand Down Expand Up @@ -497,10 +519,10 @@ typedef int16_t ft_vec_t;

template <NnueType Nt, Color c, unsigned int NnueFtHalfdims, unsigned int NnuePsqtBuckets> void chessposition::UpdateAccumulator()
{
NnueArchInterface<Nt> NnueIf;
constexpr int16_t* weight = NnueIf.getFeatureWeight();
constexpr int16_t* bias = NnueIf.getFeatureBias();
constexpr int32_t* psqtweight = NnueIf.getFeaturePsqtWeight();
int16_t* weight = NnueCurrentArch->GetFeatureWeight();
int16_t* bias = NnueCurrentArch->GetFeatureBias();
int32_t* psqtweight = NnueCurrentArch->GetFeaturePsqtWeight();

constexpr unsigned int numRegs = (NUM_REGS > NnueFtHalfdims * 16 / SIMD_WIDTH ? NnueFtHalfdims * 16 / SIMD_WIDTH : NUM_REGS);
constexpr unsigned int tileHeight = numRegs * SIMD_WIDTH / 16;

Expand Down Expand Up @@ -835,10 +857,9 @@ int chessposition::Transform(clipped_t *output, int bucket)



template <NnueType Nt> int chessposition::NnueGetEval()
int chessposition::NnueGetEval()
{
NnueArchInterface<Nt> NnueIf;
return NnueIf.getEval(this);
return NnueCurrentArch->GetEval(this);
}


Expand Down Expand Up @@ -1308,15 +1329,21 @@ void NnueSqrClippedRelu<dims>::Propagate(int32_t* input, clipped_t* output)
//
void NnueInit()
{
NnueCurrentArch = nullptr;
}

void NnueRemove()
{
if (NnueCurrentArch) {
freealigned64(NnueCurrentArch);
NnueCurrentArch = nullptr;
}
}

bool NnueReadNet(NnueNetsource* nr)
{
NnueReady = NnueDisabled;
NnueRemove();

uint32_t version, hash, size;
string sarchitecture;
Expand All @@ -1332,43 +1359,56 @@ bool NnueReadNet(NnueNetsource* nr)

NnueType nt;
bool bpz;
char* buffer;
switch (version) {
case NNUEFILEVERSIONROTATE:
bpz = true;
nt = NnueArchV1;
buffer = (char*)allocalign64(sizeof(NnueArchitectureV1));
NnueCurrentArch = new(buffer) NnueArchitectureV1;
break;
case NNUEFILEVERSIONNOBPZ:
bpz = false;
nt = NnueArchV1;
buffer = (char*)allocalign64(sizeof(NnueArchitectureV1));
NnueCurrentArch = new(buffer) NnueArchitectureV1;
break;
case NNUEFILEVERSIONSFNNv5:
case NNUEFILEVERSIONSFNNv5_512:
bpz = false;
nt = NnueArchV5;
buffer = (char*)allocalign64(sizeof(NnueArchitectureV5<512>));
NnueCurrentArch = new(buffer) NnueArchitectureV5<512>;
break;
case NNUEFILEVERSIONSFNNv5_1024:
bpz = false;
nt = NnueArchV5;
buffer = (char*)allocalign64(sizeof(NnueArchitectureV5<1024>));
NnueCurrentArch = new(buffer) NnueArchitectureV5<1024>;
break;
default:
return false;
}

NnueArchitecture* filesArch = (nt == NnueArchV1 ? (NnueArchitecture*) & NnueV1 : (NnueArchitecture*)&NnueV5);

if (!filesArch)
if (!NnueCurrentArch)
return false;

uint32_t fthash = filesArch->GetFtHash();
uint32_t nethash = filesArch->GetHash();
uint32_t fthash = NnueCurrentArch->GetFtHash();
uint32_t nethash = NnueCurrentArch->GetHash();
uint32_t filehash = (fthash ^ nethash);

if (hash != filehash)
if (hash != filehash) {
NnueRemove();
return false;
}

// Read the weights of the feature transformer
if (!nr->read((unsigned char*)&hash, sizeof(uint32_t)) || hash != fthash)
return false;
if (!filesArch->ReadFeatureWeights(nr, bpz))
if (!NnueCurrentArch->ReadFeatureWeights(nr, bpz))
return false;

// Read the weights of the network layers recursively
if (!filesArch->ReadWeights(nr, nethash))
if (!NnueCurrentArch->ReadWeights(nr, nethash))
return false;

if (!nr->endOfNet())
Expand Down Expand Up @@ -1425,7 +1465,6 @@ void NnueWriteNet(vector<string> args)
size_t cs = args.size();
string NnueNetPath = "export.nnue";
int rescale = 0;
bool bpz = false;
bool zExport = false;
if (ci < cs)
NnueNetPath = args[ci++];
Expand All @@ -1435,11 +1474,6 @@ void NnueWriteNet(vector<string> args)
{
rescale = stoi(args[ci++]);
}
if (args[ci] == "bpz")
{
bpz = true;
ci++;
}
if (args[ci] == "z")
{
zExport = true;
Expand Down Expand Up @@ -1471,17 +1505,15 @@ void NnueWriteNet(vector<string> args)
return;
}

NnueArchitecture* filesArch = (NnueReady == NnueArchV1 ? (NnueArchitecture*)&NnueV1 : (NnueArchitecture*)&NnueV5);

if (rescale)
filesArch->rescaleLastLayer(rescale);
NnueCurrentArch->RescaleLastLayer(rescale);

uint32_t fthash = filesArch->GetFtHash();
uint32_t nethash = filesArch->GetHash();
uint32_t fthash = NnueCurrentArch->GetFtHash();
uint32_t nethash = NnueCurrentArch->GetHash();
uint32_t filehash = (fthash ^ nethash);

uint32_t version = (NnueReady == NnueArchV5 ? NNUEFILEVERSIONSFNNv5 : bpz ? NNUEFILEVERSIONROTATE : NNUEFILEVERSIONNOBPZ);
string sarchitecture = filesArch->getArchDescription();
uint32_t version = NnueCurrentArch->GetFileVersion();
string sarchitecture = NnueCurrentArch->GetArchDescription();
uint32_t size = (uint32_t)sarchitecture.size();

nr.write((unsigned char*)&version, sizeof(uint32_t));
Expand All @@ -1490,8 +1522,8 @@ void NnueWriteNet(vector<string> args)
nr.write((unsigned char*)&sarchitecture[0], size);
nr.write((unsigned char*)&fthash, sizeof(uint32_t));

filesArch->WriteFeatureWeights(&nr, bpz);
filesArch->WriteWeights(&nr, nethash);
NnueCurrentArch->WriteFeatureWeights(&nr, false);
NnueCurrentArch->WriteWeights(&nr, nethash);

size_t insize = nr.next - nr.readbuffer;

Expand Down Expand Up @@ -1601,7 +1633,7 @@ bool NnueNetsource::open()
if (!openOk)
guiCom << "info string The network " + en.GetNnueNetPath() + " seems corrupted or format is not supported.\n";
else
guiCom << "info string Reading network " + en.GetNnueNetPath() + " successful. Using NNUE evaluation(" + (NnueReady == NnueArchV1 ? "V1" : "V5") + ").\n";
guiCom << "info string Reading network " + en.GetNnueNetPath() + " successful. Using NNUE (" + NnueCurrentArch->GetArchName() + ").\n";

cleanup:
#ifndef NNUEINCLUDED
Expand Down Expand Up @@ -1636,9 +1668,3 @@ bool NnueNetsource::endOfNet()
{
return (next == readbuffer + readbuffersize);
}


// Explicit template instantiation
// This avoids putting these definitions in header file
template int chessposition::NnueGetEval<NnueArchV1>();
template int chessposition::NnueGetEval<NnueArchV5>();

0 comments on commit 25bf728

Please sign in to comment.