From b4216eaa82bcfcd8bab9fec7dbd1b2d614cdf8e6 Mon Sep 17 00:00:00 2001 From: QueensGambit Date: Sun, 6 Oct 2024 13:07:42 +0200 Subject: [PATCH] Add select_nn_index() for phase selection (#216) - return 0 if no phases is enabled Set Game_Phase_Definition default to "lichess" --- engine/src/searchthread.cpp | 36 +++++++++++++++++++++-------------- engine/src/searchthread.h | 7 +++++++ engine/src/uci/optionsuci.cpp | 2 +- 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/engine/src/searchthread.cpp b/engine/src/searchthread.cpp index b61ad7b5..1ff7deec 100644 --- a/engine/src/searchthread.cpp +++ b/engine/src/searchthread.cpp @@ -379,27 +379,35 @@ void SearchThread::create_mini_batch() } } +size_t SearchThread::select_nn_index() +{ + if (nets.size() == 1) { + return 0; + } + // determine majority class in current batch + using pair_type = decltype(phaseCountMap)::value_type; + auto pr = std::max_element + ( + std::begin(phaseCountMap), std::end(phaseCountMap), + [](const pair_type& p1, const pair_type& p2) { + return p1.second < p2.second; + } + ); + + GamePhase majorityPhase = pr->first; + + phaseCountMap.clear(); + return phaseToNetsIndex.at(majorityPhase); +} + void SearchThread::thread_iteration() { create_mini_batch(); #ifndef SEARCH_UCT if (newNodes->size() != 0) { - - // determine majority class in current batch - using pair_type = decltype(phaseCountMap)::value_type; - auto pr = std::max_element - ( - std::begin(phaseCountMap), std::end(phaseCountMap), - [](const pair_type& p1, const pair_type& p2) { - return p1.second < p2.second; - } - ); - - GamePhase majorityPhase = pr->first; - phaseCountMap.clear(); // query the network that corresponds to the majority phase - nets[phaseToNetsIndex.at(majorityPhase)]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs); + nets[select_nn_index()]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs); set_nn_results_to_child_nodes(); } #endif diff --git a/engine/src/searchthread.h b/engine/src/searchthread.h index 992321b0..55beb265 100644 --- a/engine/src/searchthread.h +++ b/engine/src/searchthread.h @@ -196,6 +196,13 @@ class SearchThread : public NeuralNetAPIUser * @return Q-Value converted to double */ double get_current_transposition_q_value(const Node* currentNode, ChildIdx childIdx, uint_fast32_t transposVisits); + + /** + * @brief select_nn_index Returns the index according to the majority phase in the current batch. + * If no phases is enabled, 0 will be returned. + * @return Majority phase index or 0 + */ + size_t select_nn_index(); }; void run_search_thread(SearchThread *t); diff --git a/engine/src/uci/optionsuci.cpp b/engine/src/uci/optionsuci.cpp index dce9b3e5..e3e71ae4 100644 --- a/engine/src/uci/optionsuci.cpp +++ b/engine/src/uci/optionsuci.cpp @@ -193,7 +193,7 @@ void OptionsUCI::init(OptionsMap &o) o["Use_Raw_Network"] << Option(false); o["Virtual_Style"] << Option("virtual_mix", { "virtual_loss", "virtual_visit", "virtual_offset", "virtual_mix" }); o["Virtual_Mix_Threshold"] << Option(1000, 1, 99999999); - o["Game_Phase_Definition"] << Option("movecount", { "lichess", "movecount"}); + o["Game_Phase_Definition"] << Option("lichess", { "lichess", "movecount"}); // additional UCI-Options for RL only #ifdef USE_RL o["Centi_Node_Random_Factor"] << Option(10, 0, 100);