From d98989830108f3d1499e1a52670728d7c8580f7c Mon Sep 17 00:00:00 2001 From: Disservin Date: Sat, 16 Sep 2023 16:55:37 +0200 Subject: [PATCH] add clang-format This introduces clang-format to enforce a code style for stockfish. The lack of an automated code formatter has been a long standing issue #3608 This PR includes a Makefile target to format the code accordingly. --- .clang-format | 37 + src/Makefile | 3 + src/benchmark.cpp | 255 +- src/benchmark.h | 8 +- src/bitboard.cpp | 298 +- src/bitboard.h | 468 ++- src/evaluate.cpp | 281 +- src/evaluate.h | 36 +- src/main.cpp | 24 +- src/misc.cpp | 1146 +++--- src/misc.h | 209 +- src/movegen.cpp | 373 +- src/movegen.h | 94 +- src/movepick.cpp | 574 ++- src/movepick.h | 253 +- src/nnue/evaluate_nnue.cpp | 614 ++- src/nnue/evaluate_nnue.h | 56 +- src/nnue/features/half_ka_v2_hm.cpp | 91 +- src/nnue/features/half_ka_v2_hm.h | 206 +- src/nnue/layers/affine_transform.h | 532 ++- .../layers/affine_transform_sparse_input.h | 398 +- src/nnue/layers/clipped_relu.h | 289 +- src/nnue/layers/simd.h | 242 +- src/nnue/layers/sqr_clipped_relu.h | 141 +- src/nnue/nnue_accumulator.h | 14 +- src/nnue/nnue_architecture.h | 175 +- src/nnue/nnue_common.h | 460 ++- src/nnue/nnue_feature_transformer.h | 1086 +++--- src/position.cpp | 1954 +++++----- src/position.h | 689 ++-- src/search.cpp | 3309 ++++++++--------- src/search.h | 130 +- src/syzygy/tbprobe.cpp | 2714 +++++++------- src/syzygy/tbprobe.h | 54 +- src/thread.cpp | 321 +- src/thread.h | 207 +- src/thread_win32_osx.h | 52 +- src/timeman.cpp | 158 +- src/timeman.h | 37 +- src/tt.cpp | 188 +- src/tt.h | 166 +- src/tune.cpp | 99 +- src/tune.h | 243 +- src/types.h | 758 ++-- src/uci.cpp | 589 +-- src/uci.h | 98 +- src/ucioption.cpp | 230 +- 47 files changed, 9877 insertions(+), 10482 deletions(-) create mode 100644 .clang-format diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000000..e987768aed8 --- /dev/null +++ b/.clang-format @@ -0,0 +1,37 @@ +BasedOnStyle: WebKit +AccessModifierOffset: -1 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: Consecutive +AlignConsecutiveDeclarations: Consecutive +AlignEscapedNewlines: DontAlign +AlignOperands: AlignAfterOperator +AlignTrailingComments: true +AllowShortEnumsOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: true +AllowShortIfStatementsOnASingleLine: WithoutElse +AllowShortLoopsOnASingleLine: true +AlwaysBreakTemplateDeclarations: No +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: false +BreakConstructorInitializers: BeforeColon +ColumnLimit: 100 +ContinuationIndentWidth: 2 +Cpp11BracedListStyle: true +IndentGotoLabels: false +IndentPPDirectives: BeforeHash +IndentWidth: 4 +MaxEmptyLinesToKeep: 2 +NamespaceIndentation: All +ReflowComments: false +SortIncludes: false +SortUsingDeclarations: false +SpaceAfterCStyleCast: true +SpaceAfterTemplateKeyword: false +SpaceBeforeCaseColon: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeInheritanceColon: false +SpaceInEmptyBlock: false +SpacesBeforeTrailingComments: 2 +BitFieldColonSpacing: After +BreakStringLiterals: false \ No newline at end of file diff --git a/src/Makefile b/src/Makefile index f5a420b7ce0..232fac21204 100644 --- a/src/Makefile +++ b/src/Makefile @@ -929,6 +929,9 @@ net: netvariables fi; \ fi; \ +format: + find . -type f \( -iname "*.h" -o -iname "*.cpp" \) ! -name "incbin.h" | xargs clang-format -i -style=file + # default target default: help diff --git a/src/benchmark.cpp b/src/benchmark.cpp index 8e28184a3cd..5b3ce999083 100644 --- a/src/benchmark.cpp +++ b/src/benchmark.cpp @@ -27,138 +27,129 @@ namespace { -const std::vector Defaults = { - "setoption name UCI_Chess960 value false", - "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", - "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 10", - "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 11", - "4rrk1/pp1n3p/3q2pQ/2p1pb2/2PP4/2P3N1/P2B2PP/4RRK1 b - - 7 19", - "rq3rk1/ppp2ppp/1bnpb3/3N2B1/3NP3/7P/PPPQ1PP1/2KR3R w - - 7 14 moves d4e6", - "r1bq1r1k/1pp1n1pp/1p1p4/4p2Q/4Pp2/1BNP4/PPP2PPP/3R1RK1 w - - 2 14 moves g2g4", - "r3r1k1/2p2ppp/p1p1bn2/8/1q2P3/2NPQN2/PPP3PP/R4RK1 b - - 2 15", - "r1bbk1nr/pp3p1p/2n5/1N4p1/2Np1B2/8/PPP2PPP/2KR1B1R w kq - 0 13", - "r1bq1rk1/ppp1nppp/4n3/3p3Q/3P4/1BP1B3/PP1N2PP/R4RK1 w - - 1 16", - "4r1k1/r1q2ppp/ppp2n2/4P3/5Rb1/1N1BQ3/PPP3PP/R5K1 w - - 1 17", - "2rqkb1r/ppp2p2/2npb1p1/1N1Nn2p/2P1PP2/8/PP2B1PP/R1BQK2R b KQ - 0 11", - "r1bq1r1k/b1p1npp1/p2p3p/1p6/3PP3/1B2NN2/PP3PPP/R2Q1RK1 w - - 1 16", - "3r1rk1/p5pp/bpp1pp2/8/q1PP1P2/b3P3/P2NQRPP/1R2B1K1 b - - 6 22", - "r1q2rk1/2p1bppp/2Pp4/p6b/Q1PNp3/4B3/PP1R1PPP/2K4R w - - 2 18", - "4k2r/1pb2ppp/1p2p3/1R1p4/3P4/2r1PN2/P4PPP/1R4K1 b - - 3 22", - "3q2k1/pb3p1p/4pbp1/2r5/PpN2N2/1P2P2P/5PP1/Q2R2K1 b - - 4 26", - "6k1/6p1/6Pp/ppp5/3pn2P/1P3K2/1PP2P2/3N4 b - - 0 1", - "3b4/5kp1/1p1p1p1p/pP1PpP1P/P1P1P3/3KN3/8/8 w - - 0 1", - "2K5/p7/7P/5pR1/8/5k2/r7/8 w - - 0 1 moves g5g6 f3e3 g6g5 e3f3", - "8/6pk/1p6/8/PP3p1p/5P2/4KP1q/3Q4 w - - 0 1", - "7k/3p2pp/4q3/8/4Q3/5Kp1/P6b/8 w - - 0 1", - "8/2p5/8/2kPKp1p/2p4P/2P5/3P4/8 w - - 0 1", - "8/1p3pp1/7p/5P1P/2k3P1/8/2K2P2/8 w - - 0 1", - "8/pp2r1k1/2p1p3/3pP2p/1P1P1P1P/P5KR/8/8 w - - 0 1", - "8/3p4/p1bk3p/Pp6/1Kp1PpPp/2P2P1P/2P5/5B2 b - - 0 1", - "5k2/7R/4P2p/5K2/p1r2P1p/8/8/8 b - - 0 1", - "6k1/6p1/P6p/r1N5/5p2/7P/1b3PP1/4R1K1 w - - 0 1", - "1r3k2/4q3/2Pp3b/3Bp3/2Q2p2/1p1P2P1/1P2KP2/3N4 w - - 0 1", - "6k1/4pp1p/3p2p1/P1pPb3/R7/1r2P1PP/3B1P2/6K1 w - - 0 1", - "8/3p3B/5p2/5P2/p7/PP5b/k7/6K1 w - - 0 1", - "5rk1/q6p/2p3bR/1pPp1rP1/1P1Pp3/P3B1Q1/1K3P2/R7 w - - 93 90", - "4rrk1/1p1nq3/p7/2p1P1pp/3P2bp/3Q1Bn1/PPPB4/1K2R1NR w - - 40 21", - "r3k2r/3nnpbp/q2pp1p1/p7/Pp1PPPP1/4BNN1/1P5P/R2Q1RK1 w kq - 0 16", - "3Qb1k1/1r2ppb1/pN1n2q1/Pp1Pp1Pr/4P2p/4BP2/4B1R1/1R5K b - - 11 40", - "4k3/3q1r2/1N2r1b1/3ppN2/2nPP3/1B1R2n1/2R1Q3/3K4 w - - 5 1", - - // 5-man positions - "8/8/8/8/5kp1/P7/8/1K1N4 w - - 0 1", // Kc2 - mate - "8/8/8/5N2/8/p7/8/2NK3k w - - 0 1", // Na2 - mate - "8/3k4/8/8/8/4B3/4KB2/2B5 w - - 0 1", // draw - - // 6-man positions - "8/8/1P6/5pr1/8/4R3/7k/2K5 w - - 0 1", // Re5 - mate - "8/2p4P/8/kr6/6R1/8/8/1K6 w - - 0 1", // Ka2 - mate - "8/8/3P3k/8/1p6/8/1P6/1K3n2 b - - 0 1", // Nd2 - draw - - // 7-man positions - "8/R7/2q5/8/6k1/8/1P5p/K6R w - - 0 124", // Draw - - // Mate and stalemate positions - "6k1/3b3r/1p1p4/p1n2p2/1PPNpP1q/P3Q1p1/1R1RB1P1/5K2 b - - 0 1", - "r2r1n2/pp2bk2/2p1p2p/3q4/3PN1QP/2P3R1/P4PP1/5RK1 w - - 0 1", - "8/8/8/8/8/6k1/6p1/6K1 w - -", - "7k/7P/6K1/8/3B4/8/8/8 b - -", - - // Chess 960 - "setoption name UCI_Chess960 value true", - "bbqnnrkr/pppppppp/8/8/8/8/PPPPPPPP/BBQNNRKR w HFhf - 0 1 moves g2g3 d7d5 d2d4 c8h3 c1g5 e8d6 g5e7 f7f6", - "nqbnrkrb/pppppppp/8/8/8/8/PPPPPPPP/NQBNRKRB w KQkq - 0 1", - "setoption name UCI_Chess960 value false" -}; - -} // namespace + const std::vector Defaults = { + "setoption name UCI_Chess960 value false", + "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", + "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 10", + "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 11", + "4rrk1/pp1n3p/3q2pQ/2p1pb2/2PP4/2P3N1/P2B2PP/4RRK1 b - - 7 19", + "rq3rk1/ppp2ppp/1bnpb3/3N2B1/3NP3/7P/PPPQ1PP1/2KR3R w - - 7 14 moves d4e6", + "r1bq1r1k/1pp1n1pp/1p1p4/4p2Q/4Pp2/1BNP4/PPP2PPP/3R1RK1 w - - 2 14 moves g2g4", + "r3r1k1/2p2ppp/p1p1bn2/8/1q2P3/2NPQN2/PPP3PP/R4RK1 b - - 2 15", + "r1bbk1nr/pp3p1p/2n5/1N4p1/2Np1B2/8/PPP2PPP/2KR1B1R w kq - 0 13", + "r1bq1rk1/ppp1nppp/4n3/3p3Q/3P4/1BP1B3/PP1N2PP/R4RK1 w - - 1 16", + "4r1k1/r1q2ppp/ppp2n2/4P3/5Rb1/1N1BQ3/PPP3PP/R5K1 w - - 1 17", + "2rqkb1r/ppp2p2/2npb1p1/1N1Nn2p/2P1PP2/8/PP2B1PP/R1BQK2R b KQ - 0 11", + "r1bq1r1k/b1p1npp1/p2p3p/1p6/3PP3/1B2NN2/PP3PPP/R2Q1RK1 w - - 1 16", + "3r1rk1/p5pp/bpp1pp2/8/q1PP1P2/b3P3/P2NQRPP/1R2B1K1 b - - 6 22", + "r1q2rk1/2p1bppp/2Pp4/p6b/Q1PNp3/4B3/PP1R1PPP/2K4R w - - 2 18", + "4k2r/1pb2ppp/1p2p3/1R1p4/3P4/2r1PN2/P4PPP/1R4K1 b - - 3 22", + "3q2k1/pb3p1p/4pbp1/2r5/PpN2N2/1P2P2P/5PP1/Q2R2K1 b - - 4 26", + "6k1/6p1/6Pp/ppp5/3pn2P/1P3K2/1PP2P2/3N4 b - - 0 1", + "3b4/5kp1/1p1p1p1p/pP1PpP1P/P1P1P3/3KN3/8/8 w - - 0 1", + "2K5/p7/7P/5pR1/8/5k2/r7/8 w - - 0 1 moves g5g6 f3e3 g6g5 e3f3", + "8/6pk/1p6/8/PP3p1p/5P2/4KP1q/3Q4 w - - 0 1", "7k/3p2pp/4q3/8/4Q3/5Kp1/P6b/8 w - - 0 1", + "8/2p5/8/2kPKp1p/2p4P/2P5/3P4/8 w - - 0 1", "8/1p3pp1/7p/5P1P/2k3P1/8/2K2P2/8 w - - 0 1", + "8/pp2r1k1/2p1p3/3pP2p/1P1P1P1P/P5KR/8/8 w - - 0 1", + "8/3p4/p1bk3p/Pp6/1Kp1PpPp/2P2P1P/2P5/5B2 b - - 0 1", + "5k2/7R/4P2p/5K2/p1r2P1p/8/8/8 b - - 0 1", "6k1/6p1/P6p/r1N5/5p2/7P/1b3PP1/4R1K1 w - - 0 1", + "1r3k2/4q3/2Pp3b/3Bp3/2Q2p2/1p1P2P1/1P2KP2/3N4 w - - 0 1", + "6k1/4pp1p/3p2p1/P1pPb3/R7/1r2P1PP/3B1P2/6K1 w - - 0 1", + "8/3p3B/5p2/5P2/p7/PP5b/k7/6K1 w - - 0 1", + "5rk1/q6p/2p3bR/1pPp1rP1/1P1Pp3/P3B1Q1/1K3P2/R7 w - - 93 90", + "4rrk1/1p1nq3/p7/2p1P1pp/3P2bp/3Q1Bn1/PPPB4/1K2R1NR w - - 40 21", + "r3k2r/3nnpbp/q2pp1p1/p7/Pp1PPPP1/4BNN1/1P5P/R2Q1RK1 w kq - 0 16", + "3Qb1k1/1r2ppb1/pN1n2q1/Pp1Pp1Pr/4P2p/4BP2/4B1R1/1R5K b - - 11 40", + "4k3/3q1r2/1N2r1b1/3ppN2/2nPP3/1B1R2n1/2R1Q3/3K4 w - - 5 1", + + // 5-man positions + "8/8/8/8/5kp1/P7/8/1K1N4 w - - 0 1", // Kc2 - mate + "8/8/8/5N2/8/p7/8/2NK3k w - - 0 1", // Na2 - mate + "8/3k4/8/8/8/4B3/4KB2/2B5 w - - 0 1", // draw + + // 6-man positions + "8/8/1P6/5pr1/8/4R3/7k/2K5 w - - 0 1", // Re5 - mate + "8/2p4P/8/kr6/6R1/8/8/1K6 w - - 0 1", // Ka2 - mate + "8/8/3P3k/8/1p6/8/1P6/1K3n2 b - - 0 1", // Nd2 - draw + + // 7-man positions + "8/R7/2q5/8/6k1/8/1P5p/K6R w - - 0 124", // Draw + + // Mate and stalemate positions + "6k1/3b3r/1p1p4/p1n2p2/1PPNpP1q/P3Q1p1/1R1RB1P1/5K2 b - - 0 1", + "r2r1n2/pp2bk2/2p1p2p/3q4/3PN1QP/2P3R1/P4PP1/5RK1 w - - 0 1", "8/8/8/8/8/6k1/6p1/6K1 w - -", + "7k/7P/6K1/8/3B4/8/8/8 b - -", + + // Chess 960 + "setoption name UCI_Chess960 value true", + "bbqnnrkr/pppppppp/8/8/8/8/PPPPPPPP/BBQNNRKR w HFhf - 0 1 moves g2g3 d7d5 d2d4 c8h3 c1g5 e8d6 g5e7 f7f6", + "nqbnrkrb/pppppppp/8/8/8/8/PPPPPPPP/NQBNRKRB w KQkq - 0 1", + "setoption name UCI_Chess960 value false"}; + +} // namespace namespace Stockfish { -/// setup_bench() builds a list of UCI commands to be run by bench. There -/// are five parameters: TT size in MB, number of search threads that -/// should be used, the limit value spent for each position, a file name -/// where to look for positions in FEN format, and the type of the limit: -/// depth, perft, nodes and movetime (in milliseconds). Examples: -/// -/// bench : search default positions up to depth 13 -/// bench 64 1 15 : search default positions up to depth 15 (TT = 64MB) -/// bench 64 1 100000 default nodes : search default positions for 100K nodes each -/// bench 64 4 5000 current movetime : search current position with 4 threads for 5 sec -/// bench 16 1 5 blah perft : run a perft 5 on positions in file "blah" - -std::vector setup_bench(const Position& current, std::istream& is) { - - std::vector fens, list; - std::string go, token; - - // Assign default values to missing arguments - std::string ttSize = (is >> token) ? token : "16"; - std::string threads = (is >> token) ? token : "1"; - std::string limit = (is >> token) ? token : "13"; - std::string fenFile = (is >> token) ? token : "default"; - std::string limitType = (is >> token) ? token : "depth"; - - go = limitType == "eval" ? "eval" : "go " + limitType + " " + limit; - - if (fenFile == "default") - fens = Defaults; - - else if (fenFile == "current") - fens.push_back(current.fen()); - - else - { - std::string fen; - std::ifstream file(fenFile); - - if (!file.is_open()) - { - std::cerr << "Unable to open file " << fenFile << std::endl; - exit(EXIT_FAILURE); - } - - while (getline(file, fen)) - if (!fen.empty()) - fens.push_back(fen); - - file.close(); - } - - list.emplace_back("setoption name Threads value " + threads); - list.emplace_back("setoption name Hash value " + ttSize); - list.emplace_back("ucinewgame"); - - for (const std::string& fen : fens) - if (fen.find("setoption") != std::string::npos) - list.emplace_back(fen); - else - { - list.emplace_back("position fen " + fen); - list.emplace_back(go); - } - - return list; -} - -} // namespace Stockfish + /// setup_bench() builds a list of UCI commands to be run by bench. There + /// are five parameters: TT size in MB, number of search threads that + /// should be used, the limit value spent for each position, a file name + /// where to look for positions in FEN format, and the type of the limit: + /// depth, perft, nodes and movetime (in milliseconds). Examples: + /// + /// bench : search default positions up to depth 13 + /// bench 64 1 15 : search default positions up to depth 15 (TT = 64MB) + /// bench 64 1 100000 default nodes : search default positions for 100K nodes each + /// bench 64 4 5000 current movetime : search current position with 4 threads for 5 sec + /// bench 16 1 5 blah perft : run a perft 5 on positions in file "blah" + + std::vector setup_bench(const Position& current, std::istream& is) { + + std::vector fens, list; + std::string go, token; + + // Assign default values to missing arguments + std::string ttSize = (is >> token) ? token : "16"; + std::string threads = (is >> token) ? token : "1"; + std::string limit = (is >> token) ? token : "13"; + std::string fenFile = (is >> token) ? token : "default"; + std::string limitType = (is >> token) ? token : "depth"; + + go = limitType == "eval" ? "eval" : "go " + limitType + " " + limit; + + if (fenFile == "default") + fens = Defaults; + + else if (fenFile == "current") + fens.push_back(current.fen()); + + else { + std::string fen; + std::ifstream file(fenFile); + + if (!file.is_open()) { + std::cerr << "Unable to open file " << fenFile << std::endl; + exit(EXIT_FAILURE); + } + + while (getline(file, fen)) + if (!fen.empty()) fens.push_back(fen); + + file.close(); + } + + list.emplace_back("setoption name Threads value " + threads); + list.emplace_back("setoption name Hash value " + ttSize); + list.emplace_back("ucinewgame"); + + for (const std::string& fen : fens) + if (fen.find("setoption") != std::string::npos) + list.emplace_back(fen); + else { + list.emplace_back("position fen " + fen); + list.emplace_back(go); + } + + return list; + } + +} // namespace Stockfish diff --git a/src/benchmark.h b/src/benchmark.h index 64acf833ac0..dacdbeac1d9 100644 --- a/src/benchmark.h +++ b/src/benchmark.h @@ -25,10 +25,10 @@ namespace Stockfish { -class Position; + class Position; -std::vector setup_bench(const Position&, std::istream&); + std::vector setup_bench(const Position&, std::istream&); -} // namespace Stockfish +} // namespace Stockfish -#endif // #ifndef BENCHMARK_H_INCLUDED +#endif // #ifndef BENCHMARK_H_INCLUDED diff --git a/src/bitboard.cpp b/src/bitboard.cpp index bed2b3ee309..75c326c162c 100644 --- a/src/bitboard.cpp +++ b/src/bitboard.cpp @@ -26,195 +26,181 @@ namespace Stockfish { -uint8_t PopCnt16[1 << 16]; -uint8_t SquareDistance[SQUARE_NB][SQUARE_NB]; + uint8_t PopCnt16[1 << 16]; + uint8_t SquareDistance[SQUARE_NB][SQUARE_NB]; -Bitboard LineBB[SQUARE_NB][SQUARE_NB]; -Bitboard BetweenBB[SQUARE_NB][SQUARE_NB]; -Bitboard PseudoAttacks[PIECE_TYPE_NB][SQUARE_NB]; -Bitboard PawnAttacks[COLOR_NB][SQUARE_NB]; + Bitboard LineBB[SQUARE_NB][SQUARE_NB]; + Bitboard BetweenBB[SQUARE_NB][SQUARE_NB]; + Bitboard PseudoAttacks[PIECE_TYPE_NB][SQUARE_NB]; + Bitboard PawnAttacks[COLOR_NB][SQUARE_NB]; -Magic RookMagics[SQUARE_NB]; -Magic BishopMagics[SQUARE_NB]; + Magic RookMagics[SQUARE_NB]; + Magic BishopMagics[SQUARE_NB]; -namespace { + namespace { - Bitboard RookTable[0x19000]; // To store rook attacks - Bitboard BishopTable[0x1480]; // To store bishop attacks + Bitboard RookTable[0x19000]; // To store rook attacks + Bitboard BishopTable[0x1480]; // To store bishop attacks - void init_magics(PieceType pt, Bitboard table[], Magic magics[]); + void init_magics(PieceType pt, Bitboard table[], Magic magics[]); -} + } + + /// safe_destination() returns the bitboard of target square for the given step + /// from the given square. If the step is off the board, returns empty bitboard. + + inline Bitboard safe_destination(Square s, int step) { + Square to = Square(s + step); + return is_ok(to) && distance(s, to) <= 2 ? square_bb(to) : Bitboard(0); + } -/// safe_destination() returns the bitboard of target square for the given step -/// from the given square. If the step is off the board, returns empty bitboard. -inline Bitboard safe_destination(Square s, int step) { - Square to = Square(s + step); - return is_ok(to) && distance(s, to) <= 2 ? square_bb(to) : Bitboard(0); -} + /// Bitboards::pretty() returns an ASCII representation of a bitboard suitable + /// to be printed to standard output. Useful for debugging. + std::string Bitboards::pretty(Bitboard b) { -/// Bitboards::pretty() returns an ASCII representation of a bitboard suitable -/// to be printed to standard output. Useful for debugging. + std::string s = "+---+---+---+---+---+---+---+---+\n"; -std::string Bitboards::pretty(Bitboard b) { + for (Rank r = RANK_8; r >= RANK_1; --r) { + for (File f = FILE_A; f <= FILE_H; ++f) s += b & make_square(f, r) ? "| X " : "| "; - std::string s = "+---+---+---+---+---+---+---+---+\n"; + s += "| " + std::to_string(1 + r) + "\n+---+---+---+---+---+---+---+---+\n"; + } + s += " a b c d e f g h\n"; - for (Rank r = RANK_8; r >= RANK_1; --r) - { - for (File f = FILE_A; f <= FILE_H; ++f) - s += b & make_square(f, r) ? "| X " : "| "; + return s; + } - s += "| " + std::to_string(1 + r) + "\n+---+---+---+---+---+---+---+---+\n"; - } - s += " a b c d e f g h\n"; - return s; -} + /// Bitboards::init() initializes various bitboard tables. It is called at + /// startup and relies on global objects to be already zero-initialized. + void Bitboards::init() { -/// Bitboards::init() initializes various bitboard tables. It is called at -/// startup and relies on global objects to be already zero-initialized. + for (unsigned i = 0; i < (1 << 16); ++i) PopCnt16[i] = uint8_t(std::bitset<16>(i).count()); -void Bitboards::init() { + for (Square s1 = SQ_A1; s1 <= SQ_H8; ++s1) + for (Square s2 = SQ_A1; s2 <= SQ_H8; ++s2) + SquareDistance[s1][s2] = std::max(distance(s1, s2), distance(s1, s2)); - for (unsigned i = 0; i < (1 << 16); ++i) - PopCnt16[i] = uint8_t(std::bitset<16>(i).count()); + init_magics(ROOK, RookTable, RookMagics); + init_magics(BISHOP, BishopTable, BishopMagics); - for (Square s1 = SQ_A1; s1 <= SQ_H8; ++s1) - for (Square s2 = SQ_A1; s2 <= SQ_H8; ++s2) - SquareDistance[s1][s2] = std::max(distance(s1, s2), distance(s1, s2)); + for (Square s1 = SQ_A1; s1 <= SQ_H8; ++s1) { + PawnAttacks[WHITE][s1] = pawn_attacks_bb(square_bb(s1)); + PawnAttacks[BLACK][s1] = pawn_attacks_bb(square_bb(s1)); - init_magics(ROOK, RookTable, RookMagics); - init_magics(BISHOP, BishopTable, BishopMagics); + for (int step : {-9, -8, -7, -1, 1, 7, 8, 9}) + PseudoAttacks[KING][s1] |= safe_destination(s1, step); - for (Square s1 = SQ_A1; s1 <= SQ_H8; ++s1) - { - PawnAttacks[WHITE][s1] = pawn_attacks_bb(square_bb(s1)); - PawnAttacks[BLACK][s1] = pawn_attacks_bb(square_bb(s1)); + for (int step : {-17, -15, -10, -6, 6, 10, 15, 17}) + PseudoAttacks[KNIGHT][s1] |= safe_destination(s1, step); - for (int step : {-9, -8, -7, -1, 1, 7, 8, 9} ) - PseudoAttacks[KING][s1] |= safe_destination(s1, step); + PseudoAttacks[QUEEN][s1] = PseudoAttacks[BISHOP][s1] = attacks_bb(s1, 0); + PseudoAttacks[QUEEN][s1] |= PseudoAttacks[ROOK][s1] = attacks_bb(s1, 0); - for (int step : {-17, -15, -10, -6, 6, 10, 15, 17} ) - PseudoAttacks[KNIGHT][s1] |= safe_destination(s1, step); + for (PieceType pt : {BISHOP, ROOK}) + for (Square s2 = SQ_A1; s2 <= SQ_H8; ++s2) { + if (PseudoAttacks[pt][s1] & s2) { + LineBB[s1][s2] = (attacks_bb(pt, s1, 0) & attacks_bb(pt, s2, 0)) | s1 | s2; + BetweenBB[s1][s2] = + (attacks_bb(pt, s1, square_bb(s2)) & attacks_bb(pt, s2, square_bb(s1))); + } + BetweenBB[s1][s2] |= s2; + } + } + } - PseudoAttacks[QUEEN][s1] = PseudoAttacks[BISHOP][s1] = attacks_bb(s1, 0); - PseudoAttacks[QUEEN][s1] |= PseudoAttacks[ ROOK][s1] = attacks_bb< ROOK>(s1, 0); + namespace { - for (PieceType pt : { BISHOP, ROOK }) - for (Square s2 = SQ_A1; s2 <= SQ_H8; ++s2) - { - if (PseudoAttacks[pt][s1] & s2) - { - LineBB[s1][s2] = (attacks_bb(pt, s1, 0) & attacks_bb(pt, s2, 0)) | s1 | s2; - BetweenBB[s1][s2] = (attacks_bb(pt, s1, square_bb(s2)) & attacks_bb(pt, s2, square_bb(s1))); - } - BetweenBB[s1][s2] |= s2; - } - } -} + Bitboard sliding_attack(PieceType pt, Square sq, Bitboard occupied) { -namespace { + Bitboard attacks = 0; + Direction RookDirections[4] = {NORTH, SOUTH, EAST, WEST}; + Direction BishopDirections[4] = {NORTH_EAST, SOUTH_EAST, SOUTH_WEST, NORTH_WEST}; - Bitboard sliding_attack(PieceType pt, Square sq, Bitboard occupied) { + for (Direction d : (pt == ROOK ? RookDirections : BishopDirections)) { + Square s = sq; + while (safe_destination(s, d) && !(occupied & s)) attacks |= (s += d); + } - Bitboard attacks = 0; - Direction RookDirections[4] = {NORTH, SOUTH, EAST, WEST}; - Direction BishopDirections[4] = {NORTH_EAST, SOUTH_EAST, SOUTH_WEST, NORTH_WEST}; + return attacks; + } - for (Direction d : (pt == ROOK ? RookDirections : BishopDirections)) - { - Square s = sq; - while (safe_destination(s, d) && !(occupied & s)) - attacks |= (s += d); - } - return attacks; - } - - - // init_magics() computes all rook and bishop attacks at startup. Magic - // bitboards are used to look up attacks of sliding pieces. As a reference see - // www.chessprogramming.org/Magic_Bitboards. In particular, here we use the so - // called "fancy" approach. - - void init_magics(PieceType pt, Bitboard table[], Magic magics[]) { - - // Optimal PRNG seeds to pick the correct magics in the shortest time - int seeds[][RANK_NB] = { { 8977, 44560, 54343, 38998, 5731, 95205, 104912, 17020 }, - { 728, 10316, 55013, 32803, 12281, 15100, 16645, 255 } }; - - Bitboard occupancy[4096], reference[4096], edges, b; - int epoch[4096] = {}, cnt = 0, size = 0; - - for (Square s = SQ_A1; s <= SQ_H8; ++s) - { - // Board edges are not considered in the relevant occupancies - edges = ((Rank1BB | Rank8BB) & ~rank_bb(s)) | ((FileABB | FileHBB) & ~file_bb(s)); - - // Given a square 's', the mask is the bitboard of sliding attacks from - // 's' computed on an empty board. The index must be big enough to contain - // all the attacks for each possible subset of the mask and so is 2 power - // the number of 1s of the mask. Hence we deduce the size of the shift to - // apply to the 64 or 32 bits word to get the index. - Magic& m = magics[s]; - m.mask = sliding_attack(pt, s, 0) & ~edges; - m.shift = (Is64Bit ? 64 : 32) - popcount(m.mask); - - // Set the offset for the attacks table of the square. We have individual - // table sizes for each square with "Fancy Magic Bitboards". - m.attacks = s == SQ_A1 ? table : magics[s - 1].attacks + size; - - // Use Carry-Rippler trick to enumerate all subsets of masks[s] and - // store the corresponding sliding attack bitboard in reference[]. - b = size = 0; - do { - occupancy[size] = b; - reference[size] = sliding_attack(pt, s, b); - - if (HasPext) - m.attacks[pext(b, m.mask)] = reference[size]; - - size++; - b = (b - m.mask) & m.mask; - } while (b); - - if (HasPext) - continue; - - PRNG rng(seeds[Is64Bit][rank_of(s)]); - - // Find a magic for square 's' picking up an (almost) random number - // until we find the one that passes the verification test. - for (int i = 0; i < size; ) - { - for (m.magic = 0; popcount((m.magic * m.mask) >> 56) < 6; ) - m.magic = rng.sparse_rand(); - - // A good magic must map every possible occupancy to an index that - // looks up the correct sliding attack in the attacks[s] database. - // Note that we build up the database for square 's' as a side - // effect of verifying the magic. Keep track of the attempt count - // and save it in epoch[], little speed-up trick to avoid resetting - // m.attacks[] after every failed attempt. - for (++cnt, i = 0; i < size; ++i) - { - unsigned idx = m.index(occupancy[i]); - - if (epoch[idx] < cnt) - { - epoch[idx] = cnt; - m.attacks[idx] = reference[i]; + // init_magics() computes all rook and bishop attacks at startup. Magic + // bitboards are used to look up attacks of sliding pieces. As a reference see + // www.chessprogramming.org/Magic_Bitboards. In particular, here we use the so + // called "fancy" approach. + + void init_magics(PieceType pt, Bitboard table[], Magic magics[]) { + + // Optimal PRNG seeds to pick the correct magics in the shortest time + int seeds[][RANK_NB] = {{8977, 44560, 54343, 38998, 5731, 95205, 104912, 17020}, + {728, 10316, 55013, 32803, 12281, 15100, 16645, 255}}; + + Bitboard occupancy[4096], reference[4096], edges, b; + int epoch[4096] = {}, cnt = 0, size = 0; + + for (Square s = SQ_A1; s <= SQ_H8; ++s) { + // Board edges are not considered in the relevant occupancies + edges = ((Rank1BB | Rank8BB) & ~rank_bb(s)) | ((FileABB | FileHBB) & ~file_bb(s)); + + // Given a square 's', the mask is the bitboard of sliding attacks from + // 's' computed on an empty board. The index must be big enough to contain + // all the attacks for each possible subset of the mask and so is 2 power + // the number of 1s of the mask. Hence we deduce the size of the shift to + // apply to the 64 or 32 bits word to get the index. + Magic& m = magics[s]; + m.mask = sliding_attack(pt, s, 0) & ~edges; + m.shift = (Is64Bit ? 64 : 32) - popcount(m.mask); + + // Set the offset for the attacks table of the square. We have individual + // table sizes for each square with "Fancy Magic Bitboards". + m.attacks = s == SQ_A1 ? table : magics[s - 1].attacks + size; + + // Use Carry-Rippler trick to enumerate all subsets of masks[s] and + // store the corresponding sliding attack bitboard in reference[]. + b = size = 0; + do { + occupancy[size] = b; + reference[size] = sliding_attack(pt, s, b); + + if (HasPext) m.attacks[pext(b, m.mask)] = reference[size]; + + size++; + b = (b - m.mask) & m.mask; + } while (b); + + if (HasPext) continue; + + PRNG rng(seeds[Is64Bit][rank_of(s)]); + + // Find a magic for square 's' picking up an (almost) random number + // until we find the one that passes the verification test. + for (int i = 0; i < size;) { + for (m.magic = 0; popcount((m.magic * m.mask) >> 56) < 6;) + m.magic = rng.sparse_rand(); + + // A good magic must map every possible occupancy to an index that + // looks up the correct sliding attack in the attacks[s] database. + // Note that we build up the database for square 's' as a side + // effect of verifying the magic. Keep track of the attempt count + // and save it in epoch[], little speed-up trick to avoid resetting + // m.attacks[] after every failed attempt. + for (++cnt, i = 0; i < size; ++i) { + unsigned idx = m.index(occupancy[i]); + + if (epoch[idx] < cnt) { + epoch[idx] = cnt; + m.attacks[idx] = reference[i]; + } else if (m.attacks[idx] != reference[i]) + break; + } } - else if (m.attacks[idx] != reference[i]) - break; } } } - } -} -} // namespace Stockfish +} // namespace Stockfish diff --git a/src/bitboard.h b/src/bitboard.h index c05b6e3f8cf..8116d6444bc 100644 --- a/src/bitboard.h +++ b/src/bitboard.h @@ -30,335 +30,327 @@ namespace Stockfish { -namespace Bitboards { + namespace Bitboards { -void init(); -std::string pretty(Bitboard b); + void init(); + std::string pretty(Bitboard b); -} // namespace Stockfish::Bitboards + } // namespace Stockfish::Bitboards -constexpr Bitboard FileABB = 0x0101010101010101ULL; -constexpr Bitboard FileBBB = FileABB << 1; -constexpr Bitboard FileCBB = FileABB << 2; -constexpr Bitboard FileDBB = FileABB << 3; -constexpr Bitboard FileEBB = FileABB << 4; -constexpr Bitboard FileFBB = FileABB << 5; -constexpr Bitboard FileGBB = FileABB << 6; -constexpr Bitboard FileHBB = FileABB << 7; + constexpr Bitboard FileABB = 0x0101010101010101ULL; + constexpr Bitboard FileBBB = FileABB << 1; + constexpr Bitboard FileCBB = FileABB << 2; + constexpr Bitboard FileDBB = FileABB << 3; + constexpr Bitboard FileEBB = FileABB << 4; + constexpr Bitboard FileFBB = FileABB << 5; + constexpr Bitboard FileGBB = FileABB << 6; + constexpr Bitboard FileHBB = FileABB << 7; -constexpr Bitboard Rank1BB = 0xFF; -constexpr Bitboard Rank2BB = Rank1BB << (8 * 1); -constexpr Bitboard Rank3BB = Rank1BB << (8 * 2); -constexpr Bitboard Rank4BB = Rank1BB << (8 * 3); -constexpr Bitboard Rank5BB = Rank1BB << (8 * 4); -constexpr Bitboard Rank6BB = Rank1BB << (8 * 5); -constexpr Bitboard Rank7BB = Rank1BB << (8 * 6); -constexpr Bitboard Rank8BB = Rank1BB << (8 * 7); + constexpr Bitboard Rank1BB = 0xFF; + constexpr Bitboard Rank2BB = Rank1BB << (8 * 1); + constexpr Bitboard Rank3BB = Rank1BB << (8 * 2); + constexpr Bitboard Rank4BB = Rank1BB << (8 * 3); + constexpr Bitboard Rank5BB = Rank1BB << (8 * 4); + constexpr Bitboard Rank6BB = Rank1BB << (8 * 5); + constexpr Bitboard Rank7BB = Rank1BB << (8 * 6); + constexpr Bitboard Rank8BB = Rank1BB << (8 * 7); -extern uint8_t PopCnt16[1 << 16]; -extern uint8_t SquareDistance[SQUARE_NB][SQUARE_NB]; + extern uint8_t PopCnt16[1 << 16]; + extern uint8_t SquareDistance[SQUARE_NB][SQUARE_NB]; -extern Bitboard BetweenBB[SQUARE_NB][SQUARE_NB]; -extern Bitboard LineBB[SQUARE_NB][SQUARE_NB]; -extern Bitboard PseudoAttacks[PIECE_TYPE_NB][SQUARE_NB]; -extern Bitboard PawnAttacks[COLOR_NB][SQUARE_NB]; + extern Bitboard BetweenBB[SQUARE_NB][SQUARE_NB]; + extern Bitboard LineBB[SQUARE_NB][SQUARE_NB]; + extern Bitboard PseudoAttacks[PIECE_TYPE_NB][SQUARE_NB]; + extern Bitboard PawnAttacks[COLOR_NB][SQUARE_NB]; -/// Magic holds all magic bitboards relevant data for a single square -struct Magic { - Bitboard mask; - Bitboard magic; - Bitboard* attacks; - unsigned shift; + /// Magic holds all magic bitboards relevant data for a single square + struct Magic { + Bitboard mask; + Bitboard magic; + Bitboard* attacks; + unsigned shift; - // Compute the attack's index using the 'magic bitboards' approach - unsigned index(Bitboard occupied) const { + // Compute the attack's index using the 'magic bitboards' approach + unsigned index(Bitboard occupied) const { - if (HasPext) - return unsigned(pext(occupied, mask)); + if (HasPext) return unsigned(pext(occupied, mask)); - if (Is64Bit) - return unsigned(((occupied & mask) * magic) >> shift); + if (Is64Bit) return unsigned(((occupied & mask) * magic) >> shift); - unsigned lo = unsigned(occupied) & unsigned(mask); - unsigned hi = unsigned(occupied >> 32) & unsigned(mask >> 32); - return (lo * unsigned(magic) ^ hi * unsigned(magic >> 32)) >> shift; - } -}; + unsigned lo = unsigned(occupied) & unsigned(mask); + unsigned hi = unsigned(occupied >> 32) & unsigned(mask >> 32); + return (lo * unsigned(magic) ^ hi * unsigned(magic >> 32)) >> shift; + } + }; -extern Magic RookMagics[SQUARE_NB]; -extern Magic BishopMagics[SQUARE_NB]; + extern Magic RookMagics[SQUARE_NB]; + extern Magic BishopMagics[SQUARE_NB]; -inline Bitboard square_bb(Square s) { - assert(is_ok(s)); - return (1ULL << s); -} + inline Bitboard square_bb(Square s) { + assert(is_ok(s)); + return (1ULL << s); + } -/// Overloads of bitwise operators between a Bitboard and a Square for testing -/// whether a given bit is set in a bitboard, and for setting and clearing bits. + /// Overloads of bitwise operators between a Bitboard and a Square for testing + /// whether a given bit is set in a bitboard, and for setting and clearing bits. -inline Bitboard operator&( Bitboard b, Square s) { return b & square_bb(s); } -inline Bitboard operator|( Bitboard b, Square s) { return b | square_bb(s); } -inline Bitboard operator^( Bitboard b, Square s) { return b ^ square_bb(s); } -inline Bitboard& operator|=(Bitboard& b, Square s) { return b |= square_bb(s); } -inline Bitboard& operator^=(Bitboard& b, Square s) { return b ^= square_bb(s); } + inline Bitboard operator&(Bitboard b, Square s) { return b & square_bb(s); } + inline Bitboard operator|(Bitboard b, Square s) { return b | square_bb(s); } + inline Bitboard operator^(Bitboard b, Square s) { return b ^ square_bb(s); } + inline Bitboard& operator|=(Bitboard& b, Square s) { return b |= square_bb(s); } + inline Bitboard& operator^=(Bitboard& b, Square s) { return b ^= square_bb(s); } -inline Bitboard operator&(Square s, Bitboard b) { return b & s; } -inline Bitboard operator|(Square s, Bitboard b) { return b | s; } -inline Bitboard operator^(Square s, Bitboard b) { return b ^ s; } + inline Bitboard operator&(Square s, Bitboard b) { return b & s; } + inline Bitboard operator|(Square s, Bitboard b) { return b | s; } + inline Bitboard operator^(Square s, Bitboard b) { return b ^ s; } -inline Bitboard operator|(Square s1, Square s2) { return square_bb(s1) | s2; } + inline Bitboard operator|(Square s1, Square s2) { return square_bb(s1) | s2; } -constexpr bool more_than_one(Bitboard b) { - return b & (b - 1); -} + constexpr bool more_than_one(Bitboard b) { return b & (b - 1); } -/// rank_bb() and file_bb() return a bitboard representing all the squares on -/// the given file or rank. + /// rank_bb() and file_bb() return a bitboard representing all the squares on + /// the given file or rank. -constexpr Bitboard rank_bb(Rank r) { - return Rank1BB << (8 * r); -} + constexpr Bitboard rank_bb(Rank r) { return Rank1BB << (8 * r); } -constexpr Bitboard rank_bb(Square s) { - return rank_bb(rank_of(s)); -} + constexpr Bitboard rank_bb(Square s) { return rank_bb(rank_of(s)); } -constexpr Bitboard file_bb(File f) { - return FileABB << f; -} + constexpr Bitboard file_bb(File f) { return FileABB << f; } -constexpr Bitboard file_bb(Square s) { - return file_bb(file_of(s)); -} + constexpr Bitboard file_bb(Square s) { return file_bb(file_of(s)); } -/// shift() moves a bitboard one or two steps as specified by the direction D + /// shift() moves a bitboard one or two steps as specified by the direction D -template -constexpr Bitboard shift(Bitboard b) { - return D == NORTH ? b << 8 : D == SOUTH ? b >> 8 - : D == NORTH+NORTH? b <<16 : D == SOUTH+SOUTH? b >>16 - : D == EAST ? (b & ~FileHBB) << 1 : D == WEST ? (b & ~FileABB) >> 1 - : D == NORTH_EAST ? (b & ~FileHBB) << 9 : D == NORTH_WEST ? (b & ~FileABB) << 7 - : D == SOUTH_EAST ? (b & ~FileHBB) >> 7 : D == SOUTH_WEST ? (b & ~FileABB) >> 9 - : 0; -} + template constexpr Bitboard shift(Bitboard b) { + return D == NORTH ? b << 8 : + D == SOUTH ? b >> 8 : + D == NORTH + NORTH ? b << 16 : + D == SOUTH + SOUTH ? b >> 16 : + D == EAST ? (b & ~FileHBB) << 1 : + D == WEST ? (b & ~FileABB) >> 1 : + D == NORTH_EAST ? (b & ~FileHBB) << 9 : + D == NORTH_WEST ? (b & ~FileABB) << 7 : + D == SOUTH_EAST ? (b & ~FileHBB) >> 7 : + D == SOUTH_WEST ? (b & ~FileABB) >> 9 : + 0; + } -/// pawn_attacks_bb() returns the squares attacked by pawns of the given color -/// from the squares in the given bitboard. + /// pawn_attacks_bb() returns the squares attacked by pawns of the given color + /// from the squares in the given bitboard. -template -constexpr Bitboard pawn_attacks_bb(Bitboard b) { - return C == WHITE ? shift(b) | shift(b) - : shift(b) | shift(b); -} + template constexpr Bitboard pawn_attacks_bb(Bitboard b) { + return C == WHITE ? shift(b) | shift(b) : + shift(b) | shift(b); + } -inline Bitboard pawn_attacks_bb(Color c, Square s) { + inline Bitboard pawn_attacks_bb(Color c, Square s) { - assert(is_ok(s)); - return PawnAttacks[c][s]; -} + assert(is_ok(s)); + return PawnAttacks[c][s]; + } -/// line_bb() returns a bitboard representing an entire line (from board edge -/// to board edge) that intersects the two given squares. If the given squares -/// are not on a same file/rank/diagonal, the function returns 0. For instance, -/// line_bb(SQ_C4, SQ_F7) will return a bitboard with the A2-G8 diagonal. + /// line_bb() returns a bitboard representing an entire line (from board edge + /// to board edge) that intersects the two given squares. If the given squares + /// are not on a same file/rank/diagonal, the function returns 0. For instance, + /// line_bb(SQ_C4, SQ_F7) will return a bitboard with the A2-G8 diagonal. -inline Bitboard line_bb(Square s1, Square s2) { + inline Bitboard line_bb(Square s1, Square s2) { - assert(is_ok(s1) && is_ok(s2)); + assert(is_ok(s1) && is_ok(s2)); - return LineBB[s1][s2]; -} + return LineBB[s1][s2]; + } -/// between_bb(s1, s2) returns a bitboard representing the squares in the semi-open -/// segment between the squares s1 and s2 (excluding s1 but including s2). If the -/// given squares are not on a same file/rank/diagonal, it returns s2. For instance, -/// between_bb(SQ_C4, SQ_F7) will return a bitboard with squares D5, E6 and F7, but -/// between_bb(SQ_E6, SQ_F8) will return a bitboard with the square F8. This trick -/// allows to generate non-king evasion moves faster: the defending piece must either -/// interpose itself to cover the check or capture the checking piece. + /// between_bb(s1, s2) returns a bitboard representing the squares in the semi-open + /// segment between the squares s1 and s2 (excluding s1 but including s2). If the + /// given squares are not on a same file/rank/diagonal, it returns s2. For instance, + /// between_bb(SQ_C4, SQ_F7) will return a bitboard with squares D5, E6 and F7, but + /// between_bb(SQ_E6, SQ_F8) will return a bitboard with the square F8. This trick + /// allows to generate non-king evasion moves faster: the defending piece must either + /// interpose itself to cover the check or capture the checking piece. -inline Bitboard between_bb(Square s1, Square s2) { + inline Bitboard between_bb(Square s1, Square s2) { - assert(is_ok(s1) && is_ok(s2)); + assert(is_ok(s1) && is_ok(s2)); - return BetweenBB[s1][s2]; -} + return BetweenBB[s1][s2]; + } -/// aligned() returns true if the squares s1, s2 and s3 are aligned either on a -/// straight or on a diagonal line. + /// aligned() returns true if the squares s1, s2 and s3 are aligned either on a + /// straight or on a diagonal line. -inline bool aligned(Square s1, Square s2, Square s3) { - return line_bb(s1, s2) & s3; -} + inline bool aligned(Square s1, Square s2, Square s3) { return line_bb(s1, s2) & s3; } -/// distance() functions return the distance between x and y, defined as the -/// number of steps for a king in x to reach y. + /// distance() functions return the distance between x and y, defined as the + /// number of steps for a king in x to reach y. -template inline int distance(Square x, Square y); -template<> inline int distance(Square x, Square y) { return std::abs(file_of(x) - file_of(y)); } -template<> inline int distance(Square x, Square y) { return std::abs(rank_of(x) - rank_of(y)); } -template<> inline int distance(Square x, Square y) { return SquareDistance[x][y]; } + template inline int distance(Square x, Square y); + template<> inline int distance(Square x, Square y) { + return std::abs(file_of(x) - file_of(y)); + } + template<> inline int distance(Square x, Square y) { + return std::abs(rank_of(x) - rank_of(y)); + } + template<> inline int distance(Square x, Square y) { return SquareDistance[x][y]; } -inline int edge_distance(File f) { return std::min(f, File(FILE_H - f)); } + inline int edge_distance(File f) { return std::min(f, File(FILE_H - f)); } -/// attacks_bb(Square) returns the pseudo attacks of the give piece type -/// assuming an empty board. + /// attacks_bb(Square) returns the pseudo attacks of the give piece type + /// assuming an empty board. -template -inline Bitboard attacks_bb(Square s) { + template inline Bitboard attacks_bb(Square s) { - assert((Pt != PAWN) && (is_ok(s))); + assert((Pt != PAWN) && (is_ok(s))); - return PseudoAttacks[Pt][s]; -} + return PseudoAttacks[Pt][s]; + } -/// attacks_bb(Square, Bitboard) returns the attacks by the given piece -/// assuming the board is occupied according to the passed Bitboard. -/// Sliding piece attacks do not continue passed an occupied square. + /// attacks_bb(Square, Bitboard) returns the attacks by the given piece + /// assuming the board is occupied according to the passed Bitboard. + /// Sliding piece attacks do not continue passed an occupied square. -template -inline Bitboard attacks_bb(Square s, Bitboard occupied) { + template inline Bitboard attacks_bb(Square s, Bitboard occupied) { - assert((Pt != PAWN) && (is_ok(s))); + assert((Pt != PAWN) && (is_ok(s))); - switch (Pt) - { - case BISHOP: return BishopMagics[s].attacks[BishopMagics[s].index(occupied)]; - case ROOK : return RookMagics[s].attacks[ RookMagics[s].index(occupied)]; - case QUEEN : return attacks_bb(s, occupied) | attacks_bb(s, occupied); - default : return PseudoAttacks[Pt][s]; - } -} + switch (Pt) { + case BISHOP : return BishopMagics[s].attacks[BishopMagics[s].index(occupied)]; + case ROOK : return RookMagics[s].attacks[RookMagics[s].index(occupied)]; + case QUEEN : return attacks_bb(s, occupied) | attacks_bb(s, occupied); + default : return PseudoAttacks[Pt][s]; + } + } -inline Bitboard attacks_bb(PieceType pt, Square s, Bitboard occupied) { + inline Bitboard attacks_bb(PieceType pt, Square s, Bitboard occupied) { - assert((pt != PAWN) && (is_ok(s))); + assert((pt != PAWN) && (is_ok(s))); - switch (pt) - { - case BISHOP: return attacks_bb(s, occupied); - case ROOK : return attacks_bb< ROOK>(s, occupied); - case QUEEN : return attacks_bb(s, occupied) | attacks_bb(s, occupied); - default : return PseudoAttacks[pt][s]; - } -} + switch (pt) { + case BISHOP : return attacks_bb(s, occupied); + case ROOK : return attacks_bb(s, occupied); + case QUEEN : return attacks_bb(s, occupied) | attacks_bb(s, occupied); + default : return PseudoAttacks[pt][s]; + } + } -/// popcount() counts the number of non-zero bits in a bitboard + /// popcount() counts the number of non-zero bits in a bitboard -inline int popcount(Bitboard b) { + inline int popcount(Bitboard b) { #ifndef USE_POPCNT - union { Bitboard bb; uint16_t u[4]; } v = { b }; - return PopCnt16[v.u[0]] + PopCnt16[v.u[1]] + PopCnt16[v.u[2]] + PopCnt16[v.u[3]]; + union { + Bitboard bb; + uint16_t u[4]; + } v = {b}; + return PopCnt16[v.u[0]] + PopCnt16[v.u[1]] + PopCnt16[v.u[2]] + PopCnt16[v.u[3]]; #elif defined(_MSC_VER) - return (int)_mm_popcnt_u64(b); + return (int) _mm_popcnt_u64(b); -#else // Assumed gcc or compatible compiler +#else // Assumed gcc or compatible compiler - return __builtin_popcountll(b); + return __builtin_popcountll(b); #endif -} + } -/// lsb() and msb() return the least/most significant bit in a non-zero bitboard + /// lsb() and msb() return the least/most significant bit in a non-zero bitboard #if defined(__GNUC__) // GCC, Clang, ICX -inline Square lsb(Bitboard b) { - assert(b); - return Square(__builtin_ctzll(b)); -} + inline Square lsb(Bitboard b) { + assert(b); + return Square(__builtin_ctzll(b)); + } -inline Square msb(Bitboard b) { - assert(b); - return Square(63 ^ __builtin_clzll(b)); -} + inline Square msb(Bitboard b) { + assert(b); + return Square(63 ^ __builtin_clzll(b)); + } #elif defined(_MSC_VER) // MSVC -#ifdef _WIN64 // MSVC, WIN64 - -inline Square lsb(Bitboard b) { - assert(b); - unsigned long idx; - _BitScanForward64(&idx, b); - return (Square) idx; -} - -inline Square msb(Bitboard b) { - assert(b); - unsigned long idx; - _BitScanReverse64(&idx, b); - return (Square) idx; -} - -#else // MSVC, WIN32 - -inline Square lsb(Bitboard b) { - assert(b); - unsigned long idx; - - if (b & 0xffffffff) { - _BitScanForward(&idx, int32_t(b)); - return Square(idx); - } else { - _BitScanForward(&idx, int32_t(b >> 32)); - return Square(idx + 32); - } -} - -inline Square msb(Bitboard b) { - assert(b); - unsigned long idx; - - if (b >> 32) { - _BitScanReverse(&idx, int32_t(b >> 32)); - return Square(idx + 32); - } else { - _BitScanReverse(&idx, int32_t(b)); - return Square(idx); - } -} - -#endif + #ifdef _WIN64 // MSVC, WIN64 + + inline Square lsb(Bitboard b) { + assert(b); + unsigned long idx; + _BitScanForward64(&idx, b); + return (Square) idx; + } + + inline Square msb(Bitboard b) { + assert(b); + unsigned long idx; + _BitScanReverse64(&idx, b); + return (Square) idx; + } + + #else // MSVC, WIN32 + + inline Square lsb(Bitboard b) { + assert(b); + unsigned long idx; + + if (b & 0xffffffff) { + _BitScanForward(&idx, int32_t(b)); + return Square(idx); + } else { + _BitScanForward(&idx, int32_t(b >> 32)); + return Square(idx + 32); + } + } + + inline Square msb(Bitboard b) { + assert(b); + unsigned long idx; + + if (b >> 32) { + _BitScanReverse(&idx, int32_t(b >> 32)); + return Square(idx + 32); + } else { + _BitScanReverse(&idx, int32_t(b)); + return Square(idx); + } + } + + #endif #else // Compiler is neither GCC nor MSVC compatible -#error "Compiler not supported." + #error "Compiler not supported." #endif -/// least_significant_square_bb() returns the bitboard of the least significant -/// square of a non-zero bitboard. It is equivalent to square_bb(lsb(bb)). + /// least_significant_square_bb() returns the bitboard of the least significant + /// square of a non-zero bitboard. It is equivalent to square_bb(lsb(bb)). -inline Bitboard least_significant_square_bb(Bitboard b) { - assert(b); - return b & -b; -} + inline Bitboard least_significant_square_bb(Bitboard b) { + assert(b); + return b & -b; + } -/// pop_lsb() finds and clears the least significant bit in a non-zero bitboard + /// pop_lsb() finds and clears the least significant bit in a non-zero bitboard -inline Square pop_lsb(Bitboard& b) { - assert(b); - const Square s = lsb(b); - b &= b - 1; - return s; -} + inline Square pop_lsb(Bitboard& b) { + assert(b); + const Square s = lsb(b); + b &= b - 1; + return s; + } -} // namespace Stockfish +} // namespace Stockfish -#endif // #ifndef BITBOARD_H_INCLUDED +#endif // #ifndef BITBOARD_H_INCLUDED diff --git a/src/evaluate.cpp b/src/evaluate.cpp index 9ca0e4566f6..bbcd4f5f628 100644 --- a/src/evaluate.cpp +++ b/src/evaluate.cpp @@ -43,186 +43,187 @@ // const unsigned int gEmbeddedNNUESize; // the size of the embedded file // Note that this does not work in Microsoft Visual Studio. #if !defined(_MSC_VER) && !defined(NNUE_EMBEDDING_OFF) - INCBIN(EmbeddedNNUE, EvalFileDefaultName); +INCBIN(EmbeddedNNUE, EvalFileDefaultName); #else - const unsigned char gEmbeddedNNUEData[1] = {0x0}; - const unsigned char *const gEmbeddedNNUEEnd = &gEmbeddedNNUEData[1]; - const unsigned int gEmbeddedNNUESize = 1; +const unsigned char gEmbeddedNNUEData[1] = {0x0}; +const unsigned char* const gEmbeddedNNUEEnd = &gEmbeddedNNUEData[1]; +const unsigned int gEmbeddedNNUESize = 1; #endif namespace Stockfish { -namespace Eval { - - std::string currentEvalFileName = "None"; - - /// NNUE::init() tries to load a NNUE network at startup time, or when the engine - /// receives a UCI command "setoption name EvalFile value nn-[a-z0-9]{12}.nnue" - /// The name of the NNUE network is always retrieved from the EvalFile option. - /// We search the given network in three locations: internally (the default - /// network may be embedded in the binary), in the active working directory and - /// in the engine directory. Distro packagers may define the DEFAULT_NNUE_DIRECTORY - /// variable to have the engine search in a special directory in their distro. - - void NNUE::init() { - - std::string eval_file = std::string(Options["EvalFile"]); - if (eval_file.empty()) - eval_file = EvalFileDefaultName; - - #if defined(DEFAULT_NNUE_DIRECTORY) - std::vector dirs = { "" , "" , CommandLine::binaryDirectory , stringify(DEFAULT_NNUE_DIRECTORY) }; - #else - std::vector dirs = { "" , "" , CommandLine::binaryDirectory }; - #endif - - for (const std::string& directory : dirs) - if (currentEvalFileName != eval_file) - { - if (directory != "") - { - std::ifstream stream(directory + eval_file, std::ios::binary); - if (NNUE::load_eval(eval_file, stream)) - currentEvalFileName = eval_file; - } + namespace Eval { - if (directory == "" && eval_file == EvalFileDefaultName) - { - // C++ way to prepare a buffer for a memory stream - class MemoryBuffer : public std::basic_streambuf { - public: MemoryBuffer(char* p, size_t n) { setg(p, p, p + n); setp(p, p + n); } - }; + std::string currentEvalFileName = "None"; - MemoryBuffer buffer(const_cast(reinterpret_cast(gEmbeddedNNUEData)), - size_t(gEmbeddedNNUESize)); - (void) gEmbeddedNNUEEnd; // Silence warning on unused variable + /// NNUE::init() tries to load a NNUE network at startup time, or when the engine + /// receives a UCI command "setoption name EvalFile value nn-[a-z0-9]{12}.nnue" + /// The name of the NNUE network is always retrieved from the EvalFile option. + /// We search the given network in three locations: internally (the default + /// network may be embedded in the binary), in the active working directory and + /// in the engine directory. Distro packagers may define the DEFAULT_NNUE_DIRECTORY + /// variable to have the engine search in a special directory in their distro. - std::istream stream(&buffer); - if (NNUE::load_eval(eval_file, stream)) - currentEvalFileName = eval_file; - } + void NNUE::init() { + + std::string eval_file = std::string(Options["EvalFile"]); + if (eval_file.empty()) eval_file = EvalFileDefaultName; + +#if defined(DEFAULT_NNUE_DIRECTORY) + std::vector dirs = {"", "", CommandLine::binaryDirectory, + stringify(DEFAULT_NNUE_DIRECTORY)}; +#else + std::vector dirs = {"", "", CommandLine::binaryDirectory}; +#endif + + for (const std::string& directory : dirs) + if (currentEvalFileName != eval_file) { + if (directory != "") { + std::ifstream stream(directory + eval_file, std::ios::binary); + if (NNUE::load_eval(eval_file, stream)) currentEvalFileName = eval_file; + } + + if (directory == "" && eval_file == EvalFileDefaultName) { + // C++ way to prepare a buffer for a memory stream + class MemoryBuffer: public std::basic_streambuf { + public: + MemoryBuffer(char* p, size_t n) { + setg(p, p, p + n); + setp(p, p + n); + } + }; + + MemoryBuffer buffer( + const_cast(reinterpret_cast(gEmbeddedNNUEData)), + size_t(gEmbeddedNNUESize)); + (void) gEmbeddedNNUEEnd; // Silence warning on unused variable + + std::istream stream(&buffer); + if (NNUE::load_eval(eval_file, stream)) currentEvalFileName = eval_file; + } + } } - } - /// NNUE::verify() verifies that the last net used was loaded successfully - void NNUE::verify() { + /// NNUE::verify() verifies that the last net used was loaded successfully + void NNUE::verify() { - std::string eval_file = std::string(Options["EvalFile"]); - if (eval_file.empty()) - eval_file = EvalFileDefaultName; + std::string eval_file = std::string(Options["EvalFile"]); + if (eval_file.empty()) eval_file = EvalFileDefaultName; - if (currentEvalFileName != eval_file) - { + if (currentEvalFileName != eval_file) { - std::string msg1 = "Network evaluation parameters compatible with the engine must be available."; - std::string msg2 = "The network file " + eval_file + " was not loaded successfully."; - std::string msg3 = "The UCI option EvalFile might need to specify the full path, including the directory name, to the network file."; - std::string msg4 = "The default net can be downloaded from: https://tests.stockfishchess.org/api/nn/" + std::string(EvalFileDefaultName); - std::string msg5 = "The engine will be terminated now."; + std::string msg1 = + "Network evaluation parameters compatible with the engine must be available."; + std::string msg2 = + "The network file " + eval_file + " was not loaded successfully."; + std::string msg3 = + "The UCI option EvalFile might need to specify the full path, including the directory name, to the network file."; + std::string msg4 = + "The default net can be downloaded from: https://tests.stockfishchess.org/api/nn/" + + std::string(EvalFileDefaultName); + std::string msg5 = "The engine will be terminated now."; - sync_cout << "info string ERROR: " << msg1 << sync_endl; - sync_cout << "info string ERROR: " << msg2 << sync_endl; - sync_cout << "info string ERROR: " << msg3 << sync_endl; - sync_cout << "info string ERROR: " << msg4 << sync_endl; - sync_cout << "info string ERROR: " << msg5 << sync_endl; + sync_cout << "info string ERROR: " << msg1 << sync_endl; + sync_cout << "info string ERROR: " << msg2 << sync_endl; + sync_cout << "info string ERROR: " << msg3 << sync_endl; + sync_cout << "info string ERROR: " << msg4 << sync_endl; + sync_cout << "info string ERROR: " << msg5 << sync_endl; - exit(EXIT_FAILURE); - } + exit(EXIT_FAILURE); + } - sync_cout << "info string NNUE evaluation using " << eval_file << sync_endl; - } -} + sync_cout << "info string NNUE evaluation using " << eval_file << sync_endl; + } + } -/// simple_eval() returns a static, purely materialistic evaluation of the position -/// from the point of view of the given color. It can be divided by PawnValue to get -/// an approximation of the material advantage on the board in terms of pawns. + /// simple_eval() returns a static, purely materialistic evaluation of the position + /// from the point of view of the given color. It can be divided by PawnValue to get + /// an approximation of the material advantage on the board in terms of pawns. -Value Eval::simple_eval(const Position& pos, Color c) { - return PawnValue * (pos.count(c) - pos.count(~c)) - + (pos.non_pawn_material(c) - pos.non_pawn_material(~c)); -} + Value Eval::simple_eval(const Position& pos, Color c) { + return PawnValue * (pos.count(c) - pos.count(~c)) + + (pos.non_pawn_material(c) - pos.non_pawn_material(~c)); + } -/// evaluate() is the evaluator for the outer world. It returns a static evaluation -/// of the position from the point of view of the side to move. + /// evaluate() is the evaluator for the outer world. It returns a static evaluation + /// of the position from the point of view of the side to move. -Value Eval::evaluate(const Position& pos) { + Value Eval::evaluate(const Position& pos) { - assert(!pos.checkers()); + assert(!pos.checkers()); - Value v; - Color stm = pos.side_to_move(); - int shuffling = pos.rule50_count(); - int simpleEval = simple_eval(pos, stm) + (int(pos.key() & 7) - 3); + Value v; + Color stm = pos.side_to_move(); + int shuffling = pos.rule50_count(); + int simpleEval = simple_eval(pos, stm) + (int(pos.key() & 7) - 3); - bool lazy = abs(simpleEval) >= RookValue + KnightValue - + 16 * shuffling * shuffling - + abs(pos.this_thread()->bestValue) - + abs(pos.this_thread()->rootSimpleEval); + bool lazy = abs(simpleEval) >= RookValue + KnightValue + 16 * shuffling * shuffling + + abs(pos.this_thread()->bestValue) + + abs(pos.this_thread()->rootSimpleEval); - if (lazy) - v = Value(simpleEval); - else - { - int nnueComplexity; - Value nnue = NNUE::evaluate(pos, true, &nnueComplexity); + if (lazy) + v = Value(simpleEval); + else { + int nnueComplexity; + Value nnue = NNUE::evaluate(pos, true, &nnueComplexity); - Value optimism = pos.this_thread()->optimism[stm]; + Value optimism = pos.this_thread()->optimism[stm]; - // Blend optimism and eval with nnue complexity and material imbalance - optimism += optimism * (nnueComplexity + abs(simpleEval - nnue)) / 512; - nnue -= nnue * (nnueComplexity + abs(simpleEval - nnue)) / 32768; + // Blend optimism and eval with nnue complexity and material imbalance + optimism += optimism * (nnueComplexity + abs(simpleEval - nnue)) / 512; + nnue -= nnue * (nnueComplexity + abs(simpleEval - nnue)) / 32768; - int npm = pos.non_pawn_material() / 64; - v = ( nnue * (915 + npm + 9 * pos.count()) - + optimism * (154 + npm + pos.count())) / 1024; - } + int npm = pos.non_pawn_material() / 64; + v = (nnue * (915 + npm + 9 * pos.count()) + + optimism * (154 + npm + pos.count())) / + 1024; + } - // Damp down the evaluation linearly when shuffling - v = v * (200 - shuffling) / 214; + // Damp down the evaluation linearly when shuffling + v = v * (200 - shuffling) / 214; - // Guarantee evaluation does not hit the tablebase range - v = std::clamp(v, VALUE_TB_LOSS_IN_MAX_PLY + 1, VALUE_TB_WIN_IN_MAX_PLY - 1); + // Guarantee evaluation does not hit the tablebase range + v = std::clamp(v, VALUE_TB_LOSS_IN_MAX_PLY + 1, VALUE_TB_WIN_IN_MAX_PLY - 1); - return v; -} + return v; + } -/// trace() is like evaluate(), but instead of returning a value, it returns -/// a string (suitable for outputting to stdout) that contains the detailed -/// descriptions and values of each evaluation term. Useful for debugging. -/// Trace scores are from white's point of view + /// trace() is like evaluate(), but instead of returning a value, it returns + /// a string (suitable for outputting to stdout) that contains the detailed + /// descriptions and values of each evaluation term. Useful for debugging. + /// Trace scores are from white's point of view -std::string Eval::trace(Position& pos) { + std::string Eval::trace(Position& pos) { - if (pos.checkers()) - return "Final evaluation: none (in check)"; + if (pos.checkers()) return "Final evaluation: none (in check)"; - // Reset any global variable used in eval - pos.this_thread()->bestValue = VALUE_ZERO; - pos.this_thread()->rootSimpleEval = VALUE_ZERO; - pos.this_thread()->optimism[WHITE] = VALUE_ZERO; - pos.this_thread()->optimism[BLACK] = VALUE_ZERO; + // Reset any global variable used in eval + pos.this_thread()->bestValue = VALUE_ZERO; + pos.this_thread()->rootSimpleEval = VALUE_ZERO; + pos.this_thread()->optimism[WHITE] = VALUE_ZERO; + pos.this_thread()->optimism[BLACK] = VALUE_ZERO; - std::stringstream ss; - ss << std::showpoint << std::noshowpos << std::fixed << std::setprecision(2); - ss << '\n' << NNUE::trace(pos) << '\n'; + std::stringstream ss; + ss << std::showpoint << std::noshowpos << std::fixed << std::setprecision(2); + ss << '\n' << NNUE::trace(pos) << '\n'; - ss << std::showpoint << std::showpos << std::fixed << std::setprecision(2) << std::setw(15); + ss << std::showpoint << std::showpos << std::fixed << std::setprecision(2) << std::setw(15); - Value v; - v = NNUE::evaluate(pos, false); - v = pos.side_to_move() == WHITE ? v : -v; - ss << "NNUE evaluation " << 0.01 * UCI::to_cp(v) << " (white side)\n"; + Value v; + v = NNUE::evaluate(pos, false); + v = pos.side_to_move() == WHITE ? v : -v; + ss << "NNUE evaluation " << 0.01 * UCI::to_cp(v) << " (white side)\n"; - v = evaluate(pos); - v = pos.side_to_move() == WHITE ? v : -v; - ss << "Final evaluation " << 0.01 * UCI::to_cp(v) << " (white side)"; - ss << " [with scaled NNUE, ...]"; - ss << "\n"; + v = evaluate(pos); + v = pos.side_to_move() == WHITE ? v : -v; + ss << "Final evaluation " << 0.01 * UCI::to_cp(v) << " (white side)"; + ss << " [with scaled NNUE, ...]"; + ss << "\n"; - return ss.str(); -} + return ss.str(); + } -} // namespace Stockfish +} // namespace Stockfish diff --git a/src/evaluate.h b/src/evaluate.h index 8ac24daea17..14c493612c1 100644 --- a/src/evaluate.h +++ b/src/evaluate.h @@ -25,32 +25,32 @@ namespace Stockfish { -class Position; -enum Value : int; + class Position; + enum Value : int; -namespace Eval { + namespace Eval { - std::string trace(Position& pos); + std::string trace(Position& pos); - Value simple_eval(const Position& pos, Color c); - Value evaluate(const Position& pos); + Value simple_eval(const Position& pos, Color c); + Value evaluate(const Position& pos); - extern std::string currentEvalFileName; + extern std::string currentEvalFileName; - // The default net name MUST follow the format nn-[SHA256 first 12 digits].nnue - // for the build process (profile-build and fishtest) to work. Do not change the - // name of the macro, as it is used in the Makefile. - #define EvalFileDefaultName "nn-1ee1aba5ed4c.nnue" +// The default net name MUST follow the format nn-[SHA256 first 12 digits].nnue +// for the build process (profile-build and fishtest) to work. Do not change the +// name of the macro, as it is used in the Makefile. +#define EvalFileDefaultName "nn-1ee1aba5ed4c.nnue" - namespace NNUE { + namespace NNUE { - void init(); - void verify(); + void init(); + void verify(); - } // namespace NNUE + } // namespace NNUE -} // namespace Eval + } // namespace Eval -} // namespace Stockfish +} // namespace Stockfish -#endif // #ifndef EVALUATE_H_INCLUDED +#endif // #ifndef EVALUATE_H_INCLUDED diff --git a/src/main.cpp b/src/main.cpp index eee149fb455..04879cc4673 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -33,19 +33,19 @@ using namespace Stockfish; int main(int argc, char* argv[]) { - std::cout << engine_info() << std::endl; + std::cout << engine_info() << std::endl; - CommandLine::init(argc, argv); - UCI::init(Options); - Tune::init(); - Bitboards::init(); - Position::init(); - Threads.set(size_t(Options["Threads"])); - Search::clear(); // After threads are up - Eval::NNUE::init(); + CommandLine::init(argc, argv); + UCI::init(Options); + Tune::init(); + Bitboards::init(); + Position::init(); + Threads.set(size_t(Options["Threads"])); + Search::clear(); // After threads are up + Eval::NNUE::init(); - UCI::loop(argc, argv); + UCI::loop(argc, argv); - Threads.set(0); - return 0; + Threads.set(0); + return 0; } diff --git a/src/misc.cpp b/src/misc.cpp index 83ea8e10fbf..9be55b2790e 100644 --- a/src/misc.cpp +++ b/src/misc.cpp @@ -19,30 +19,30 @@ #include "misc.h" #ifdef _WIN32 -#if _WIN32_WINNT < 0x0601 -#undef _WIN32_WINNT -#define _WIN32_WINNT 0x0601 // Force to include needed API prototypes -#endif + #if _WIN32_WINNT < 0x0601 + #undef _WIN32_WINNT + #define _WIN32_WINNT 0x0601 // Force to include needed API prototypes + #endif -#ifndef NOMINMAX -#define NOMINMAX -#endif + #ifndef NOMINMAX + #define NOMINMAX + #endif -#include + #include // The needed Windows API for processor groups could be missed from old Windows // versions, so instead of calling them directly (forcing the linker to resolve // the calls at compile time), try to load them at runtime. To do this we need // first to define the corresponding function pointers. extern "C" { -using fun1_t = bool(*)(LOGICAL_PROCESSOR_RELATIONSHIP, - PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, PDWORD); -using fun2_t = bool(*)(USHORT, PGROUP_AFFINITY); -using fun3_t = bool(*)(HANDLE, CONST GROUP_AFFINITY*, PGROUP_AFFINITY); -using fun4_t = bool(*)(USHORT, PGROUP_AFFINITY, USHORT, PUSHORT); -using fun5_t = WORD(*)(); -using fun6_t = bool(*)(HANDLE, DWORD, PHANDLE); -using fun7_t = bool(*)(LPCSTR, LPCSTR, PLUID); -using fun8_t = bool(*)(HANDLE, BOOL, PTOKEN_PRIVILEGES, DWORD, PTOKEN_PRIVILEGES, PDWORD); +using fun1_t = bool (*)(LOGICAL_PROCESSOR_RELATIONSHIP, PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, + PDWORD); +using fun2_t = bool (*)(USHORT, PGROUP_AFFINITY); +using fun3_t = bool (*)(HANDLE, CONST GROUP_AFFINITY*, PGROUP_AFFINITY); +using fun4_t = bool (*)(USHORT, PGROUP_AFFINITY, USHORT, PUSHORT); +using fun5_t = WORD (*)(); +using fun6_t = bool (*)(HANDLE, DWORD, PHANDLE); +using fun7_t = bool (*)(LPCSTR, LPCSTR, PLUID); +using fun8_t = bool (*)(HANDLE, BOOL, PTOKEN_PRIVILEGES, DWORD, PTOKEN_PRIVILEGES, PDWORD); } #endif @@ -59,358 +59,341 @@ using fun8_t = bool(*)(HANDLE, BOOL, PTOKEN_PRIVILEGES, DWORD, PTOKEN_PRIVILEGES #include "types.h" #if defined(__linux__) && !defined(__ANDROID__) -#include + #include #endif -#if defined(__APPLE__) || defined(__ANDROID__) || defined(__OpenBSD__) || (defined(__GLIBCXX__) && !defined(_GLIBCXX_HAVE_ALIGNED_ALLOC) && !defined(_WIN32)) || defined(__e2k__) -#define POSIXALIGNEDALLOC -#include +#if defined(__APPLE__) || defined(__ANDROID__) || defined(__OpenBSD__) || \ + (defined(__GLIBCXX__) && !defined(_GLIBCXX_HAVE_ALIGNED_ALLOC) && !defined(_WIN32)) || \ + defined(__e2k__) + #define POSIXALIGNEDALLOC + #include #endif namespace Stockfish { -namespace { + namespace { -/// Version number or dev. -constexpr std::string_view version = "dev"; + /// Version number or dev. + constexpr std::string_view version = "dev"; -/// Our fancy logging facility. The trick here is to replace cin.rdbuf() and -/// cout.rdbuf() with two Tie objects that tie cin and cout to a file stream. We -/// can toggle the logging of std::cout and std:cin at runtime whilst preserving -/// usual I/O functionality, all without changing a single line of code! -/// Idea from http://groups.google.com/group/comp.lang.c++/msg/1d941c0f26ea0d81 + /// Our fancy logging facility. The trick here is to replace cin.rdbuf() and + /// cout.rdbuf() with two Tie objects that tie cin and cout to a file stream. We + /// can toggle the logging of std::cout and std:cin at runtime whilst preserving + /// usual I/O functionality, all without changing a single line of code! + /// Idea from http://groups.google.com/group/comp.lang.c++/msg/1d941c0f26ea0d81 -struct Tie: public std::streambuf { // MSVC requires split streambuf for cin and cout + struct Tie: public std::streambuf { // MSVC requires split streambuf for cin and cout - Tie(std::streambuf* b, std::streambuf* l) : buf(b), logBuf(l) {} + Tie(std::streambuf* b, std::streambuf* l) : buf(b), logBuf(l) {} - int sync() override { return logBuf->pubsync(), buf->pubsync(); } - int overflow(int c) override { return log(buf->sputc((char)c), "<< "); } - int underflow() override { return buf->sgetc(); } - int uflow() override { return log(buf->sbumpc(), ">> "); } + int sync() override { return logBuf->pubsync(), buf->pubsync(); } + int overflow(int c) override { return log(buf->sputc((char) c), "<< "); } + int underflow() override { return buf->sgetc(); } + int uflow() override { return log(buf->sbumpc(), ">> "); } - std::streambuf *buf, *logBuf; + std::streambuf *buf, *logBuf; - int log(int c, const char* prefix) { + int log(int c, const char* prefix) { - static int last = '\n'; // Single log file + static int last = '\n'; // Single log file - if (last == '\n') - logBuf->sputn(prefix, 3); + if (last == '\n') logBuf->sputn(prefix, 3); - return last = logBuf->sputc((char)c); - } -}; + return last = logBuf->sputc((char) c); + } + }; -class Logger { + class Logger { - Logger() : in(std::cin.rdbuf(), file.rdbuf()), out(std::cout.rdbuf(), file.rdbuf()) {} - ~Logger() { start(""); } + Logger() : in(std::cin.rdbuf(), file.rdbuf()), out(std::cout.rdbuf(), file.rdbuf()) {} + ~Logger() { start(""); } - std::ofstream file; - Tie in, out; + std::ofstream file; + Tie in, out; -public: - static void start(const std::string& fname) { + public: + static void start(const std::string& fname) { - static Logger l; + static Logger l; - if (l.file.is_open()) - { - std::cout.rdbuf(l.out.buf); - std::cin.rdbuf(l.in.buf); - l.file.close(); - } + if (l.file.is_open()) { + std::cout.rdbuf(l.out.buf); + std::cin.rdbuf(l.in.buf); + l.file.close(); + } - if (!fname.empty()) - { - l.file.open(fname, std::ifstream::out); + if (!fname.empty()) { + l.file.open(fname, std::ifstream::out); - if (!l.file.is_open()) - { - std::cerr << "Unable to open debug log file " << fname << std::endl; - exit(EXIT_FAILURE); - } + if (!l.file.is_open()) { + std::cerr << "Unable to open debug log file " << fname << std::endl; + exit(EXIT_FAILURE); + } - std::cin.rdbuf(&l.in); - std::cout.rdbuf(&l.out); - } - } -}; - -} // namespace - - -/// engine_info() returns the full name of the current Stockfish version. -/// For local dev compiles we try to append the commit sha and commit date -/// from git if that fails only the local compilation date is set and "nogit" is specified: -/// Stockfish dev-YYYYMMDD-SHA -/// or -/// Stockfish dev-YYYYMMDD-nogit -/// -/// For releases (non dev builds) we only include the version number: -/// Stockfish version - -std::string engine_info(bool to_uci) { - std::stringstream ss; - ss << "Stockfish " << version << std::setfill('0'); - - if constexpr (version == "dev") - { - ss << "-"; - #ifdef GIT_DATE - ss << stringify(GIT_DATE); - #else - constexpr std::string_view months("Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec"); - std::string month, day, year; - std::stringstream date(__DATE__); // From compiler, format is "Sep 21 2008" - - date >> month >> day >> year; - ss << year << std::setw(2) << std::setfill('0') << (1 + months.find(month) / 4) << std::setw(2) << std::setfill('0') << day; - #endif - - ss << "-"; - - #ifdef GIT_SHA - ss << stringify(GIT_SHA); - #else - ss << "nogit"; - #endif - } - - ss << (to_uci ? "\nid author ": " by ") - << "the Stockfish developers (see AUTHORS file)"; - - return ss.str(); -} + std::cin.rdbuf(&l.in); + std::cout.rdbuf(&l.out); + } + } + }; + } // namespace -/// compiler_info() returns a string trying to describe the compiler we use -std::string compiler_info() { + /// engine_info() returns the full name of the current Stockfish version. + /// For local dev compiles we try to append the commit sha and commit date + /// from git if that fails only the local compilation date is set and "nogit" is specified: + /// Stockfish dev-YYYYMMDD-SHA + /// or + /// Stockfish dev-YYYYMMDD-nogit + /// + /// For releases (non dev builds) we only include the version number: + /// Stockfish version - #define make_version_string(major, minor, patch) stringify(major) "." stringify(minor) "." stringify(patch) + std::string engine_info(bool to_uci) { + std::stringstream ss; + ss << "Stockfish " << version << std::setfill('0'); -/// Predefined macros hell: -/// -/// __GNUC__ Compiler is GCC, Clang or ICX -/// __clang__ Compiler is Clang or ICX -/// __INTEL_LLVM_COMPILER Compiler is ICX -/// _MSC_VER Compiler is MSVC -/// _WIN32 Building on Windows (any) -/// _WIN64 Building on Windows 64 bit + if constexpr (version == "dev") { + ss << "-"; +#ifdef GIT_DATE + ss << stringify(GIT_DATE); +#else + constexpr std::string_view months("Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec"); + std::string month, day, year; + std::stringstream date(__DATE__); // From compiler, format is "Sep 21 2008" - std::string compiler = "\nCompiled by "; + date >> month >> day >> year; + ss << year << std::setw(2) << std::setfill('0') << (1 + months.find(month) / 4) + << std::setw(2) << std::setfill('0') << day; +#endif - #if defined(__INTEL_LLVM_COMPILER) - compiler += "ICX "; - compiler += stringify(__INTEL_LLVM_COMPILER); - #elif defined(__clang__) - compiler += "clang++ "; - compiler += make_version_string(__clang_major__, __clang_minor__, __clang_patchlevel__); - #elif _MSC_VER - compiler += "MSVC "; - compiler += "(version "; - compiler += stringify(_MSC_FULL_VER) "." stringify(_MSC_BUILD); - compiler += ")"; - #elif defined(__e2k__) && defined(__LCC__) + ss << "-"; + +#ifdef GIT_SHA + ss << stringify(GIT_SHA); +#else + ss << "nogit"; +#endif + } + + ss << (to_uci ? "\nid author " : " by ") << "the Stockfish developers (see AUTHORS file)"; + + return ss.str(); + } + + + /// compiler_info() returns a string trying to describe the compiler we use + + std::string compiler_info() { + +#define make_version_string(major, minor, patch) \ + stringify(major) "." stringify(minor) "." stringify(patch) + + /// Predefined macros hell: + /// + /// __GNUC__ Compiler is GCC, Clang or ICX + /// __clang__ Compiler is Clang or ICX + /// __INTEL_LLVM_COMPILER Compiler is ICX + /// _MSC_VER Compiler is MSVC + /// _WIN32 Building on Windows (any) + /// _WIN64 Building on Windows 64 bit + + std::string compiler = "\nCompiled by "; + +#if defined(__INTEL_LLVM_COMPILER) + compiler += "ICX "; + compiler += stringify(__INTEL_LLVM_COMPILER); +#elif defined(__clang__) + compiler += "clang++ "; + compiler += make_version_string(__clang_major__, __clang_minor__, __clang_patchlevel__); +#elif _MSC_VER + compiler += "MSVC "; + compiler += "(version "; + compiler += stringify(_MSC_FULL_VER) "." stringify(_MSC_BUILD); + compiler += ")"; +#elif defined(__e2k__) && defined(__LCC__) #define dot_ver2(n) \ - compiler += (char)'.'; \ - compiler += (char)('0' + (n) / 10); \ - compiler += (char)('0' + (n) % 10); - - compiler += "MCST LCC "; - compiler += "(version "; - compiler += std::to_string(__LCC__ / 100); - dot_ver2(__LCC__ % 100) - dot_ver2(__LCC_MINOR__) - compiler += ")"; - #elif __GNUC__ - compiler += "g++ (GNUC) "; - compiler += make_version_string(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__); - #else - compiler += "Unknown compiler "; - compiler += "(unknown version)"; - #endif - - #if defined(__APPLE__) - compiler += " on Apple"; - #elif defined(__CYGWIN__) - compiler += " on Cygwin"; - #elif defined(__MINGW64__) - compiler += " on MinGW64"; - #elif defined(__MINGW32__) - compiler += " on MinGW32"; - #elif defined(__ANDROID__) - compiler += " on Android"; - #elif defined(__linux__) - compiler += " on Linux"; - #elif defined(_WIN64) - compiler += " on Microsoft Windows 64-bit"; - #elif defined(_WIN32) - compiler += " on Microsoft Windows 32-bit"; - #else - compiler += " on unknown system"; - #endif - - compiler += "\nCompilation settings include: "; - compiler += (Is64Bit ? " 64bit" : " 32bit"); - #if defined(USE_VNNI) - compiler += " VNNI"; - #endif - #if defined(USE_AVX512) - compiler += " AVX512"; - #endif - compiler += (HasPext ? " BMI2" : ""); - #if defined(USE_AVX2) - compiler += " AVX2"; - #endif - #if defined(USE_SSE41) - compiler += " SSE41"; - #endif - #if defined(USE_SSSE3) - compiler += " SSSE3"; - #endif - #if defined(USE_SSE2) - compiler += " SSE2"; - #endif - compiler += (HasPopCnt ? " POPCNT" : ""); - #if defined(USE_MMX) - compiler += " MMX"; - #endif - #if defined(USE_NEON_DOTPROD) - compiler += " NEON_DOTPROD"; - #elif defined(USE_NEON) - compiler += " NEON"; - #endif - - #if !defined(NDEBUG) - compiler += " DEBUG"; - #endif - - compiler += "\n__VERSION__ macro expands to: "; - #ifdef __VERSION__ - compiler += __VERSION__; - #else - compiler += "(undefined macro)"; - #endif - compiler += "\n"; - - return compiler; -} + compiler += (char) '.'; \ + compiler += (char) ('0' + (n) / 10); \ + compiler += (char) ('0' + (n) % 10); + + compiler += "MCST LCC "; + compiler += "(version "; + compiler += std::to_string(__LCC__ / 100); + dot_ver2(__LCC__ % 100) dot_ver2(__LCC_MINOR__) compiler += ")"; +#elif __GNUC__ + compiler += "g++ (GNUC) "; + compiler += make_version_string(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__); +#else + compiler += "Unknown compiler "; + compiler += "(unknown version)"; +#endif +#if defined(__APPLE__) + compiler += " on Apple"; +#elif defined(__CYGWIN__) + compiler += " on Cygwin"; +#elif defined(__MINGW64__) + compiler += " on MinGW64"; +#elif defined(__MINGW32__) + compiler += " on MinGW32"; +#elif defined(__ANDROID__) + compiler += " on Android"; +#elif defined(__linux__) + compiler += " on Linux"; +#elif defined(_WIN64) + compiler += " on Microsoft Windows 64-bit"; +#elif defined(_WIN32) + compiler += " on Microsoft Windows 32-bit"; +#else + compiler += " on unknown system"; +#endif -/// Debug functions used mainly to collect run-time statistics -constexpr int MaxDebugSlots = 32; + compiler += "\nCompilation settings include: "; + compiler += (Is64Bit ? " 64bit" : " 32bit"); +#if defined(USE_VNNI) + compiler += " VNNI"; +#endif +#if defined(USE_AVX512) + compiler += " AVX512"; +#endif + compiler += (HasPext ? " BMI2" : ""); +#if defined(USE_AVX2) + compiler += " AVX2"; +#endif +#if defined(USE_SSE41) + compiler += " SSE41"; +#endif +#if defined(USE_SSSE3) + compiler += " SSSE3"; +#endif +#if defined(USE_SSE2) + compiler += " SSE2"; +#endif + compiler += (HasPopCnt ? " POPCNT" : ""); +#if defined(USE_MMX) + compiler += " MMX"; +#endif +#if defined(USE_NEON_DOTPROD) + compiler += " NEON_DOTPROD"; +#elif defined(USE_NEON) + compiler += " NEON"; +#endif -namespace { +#if !defined(NDEBUG) + compiler += " DEBUG"; +#endif -template -struct DebugInfo { - std::atomic data[N] = { 0 }; + compiler += "\n__VERSION__ macro expands to: "; +#ifdef __VERSION__ + compiler += __VERSION__; +#else + compiler += "(undefined macro)"; +#endif + compiler += "\n"; - constexpr inline std::atomic& operator[](int index) { return data[index]; } -}; + return compiler; + } -DebugInfo<2> hit[MaxDebugSlots]; -DebugInfo<2> mean[MaxDebugSlots]; -DebugInfo<3> stdev[MaxDebugSlots]; -DebugInfo<6> correl[MaxDebugSlots]; -} // namespace + /// Debug functions used mainly to collect run-time statistics + constexpr int MaxDebugSlots = 32; -void dbg_hit_on(bool cond, int slot) { + namespace { - ++hit[slot][0]; - if (cond) - ++hit[slot][1]; -} + template struct DebugInfo { + std::atomic data[N] = {0}; -void dbg_mean_of(int64_t value, int slot) { + constexpr inline std::atomic& operator[](int index) { return data[index]; } + }; - ++mean[slot][0]; - mean[slot][1] += value; -} + DebugInfo<2> hit[MaxDebugSlots]; + DebugInfo<2> mean[MaxDebugSlots]; + DebugInfo<3> stdev[MaxDebugSlots]; + DebugInfo<6> correl[MaxDebugSlots]; -void dbg_stdev_of(int64_t value, int slot) { + } // namespace - ++stdev[slot][0]; - stdev[slot][1] += value; - stdev[slot][2] += value * value; -} + void dbg_hit_on(bool cond, int slot) { -void dbg_correl_of(int64_t value1, int64_t value2, int slot) { + ++hit[slot][0]; + if (cond) ++hit[slot][1]; + } - ++correl[slot][0]; - correl[slot][1] += value1; - correl[slot][2] += value1 * value1; - correl[slot][3] += value2; - correl[slot][4] += value2 * value2; - correl[slot][5] += value1 * value2; -} + void dbg_mean_of(int64_t value, int slot) { -void dbg_print() { - - int64_t n; - auto E = [&n](int64_t x) { return double(x) / n; }; - auto sqr = [](double x) { return x * x; }; - - for (int i = 0; i < MaxDebugSlots; ++i) - if ((n = hit[i][0])) - std::cerr << "Hit #" << i - << ": Total " << n << " Hits " << hit[i][1] - << " Hit Rate (%) " << 100.0 * E(hit[i][1]) - << std::endl; - - for (int i = 0; i < MaxDebugSlots; ++i) - if ((n = mean[i][0])) - { - std::cerr << "Mean #" << i - << ": Total " << n << " Mean " << E(mean[i][1]) - << std::endl; - } + ++mean[slot][0]; + mean[slot][1] += value; + } - for (int i = 0; i < MaxDebugSlots; ++i) - if ((n = stdev[i][0])) - { - double r = sqrt(E(stdev[i][2]) - sqr(E(stdev[i][1]))); - std::cerr << "Stdev #" << i - << ": Total " << n << " Stdev " << r - << std::endl; - } + void dbg_stdev_of(int64_t value, int slot) { - for (int i = 0; i < MaxDebugSlots; ++i) - if ((n = correl[i][0])) - { - double r = (E(correl[i][5]) - E(correl[i][1]) * E(correl[i][3])) - / ( sqrt(E(correl[i][2]) - sqr(E(correl[i][1]))) - * sqrt(E(correl[i][4]) - sqr(E(correl[i][3])))); - std::cerr << "Correl. #" << i - << ": Total " << n << " Coefficient " << r - << std::endl; - } -} + ++stdev[slot][0]; + stdev[slot][1] += value; + stdev[slot][2] += value * value; + } + void dbg_correl_of(int64_t value1, int64_t value2, int slot) { -/// Used to serialize access to std::cout to avoid multiple threads writing at -/// the same time. + ++correl[slot][0]; + correl[slot][1] += value1; + correl[slot][2] += value1 * value1; + correl[slot][3] += value2; + correl[slot][4] += value2 * value2; + correl[slot][5] += value1 * value2; + } -std::ostream& operator<<(std::ostream& os, SyncCout sc) { + void dbg_print() { + + int64_t n; + auto E = [&n](int64_t x) { return double(x) / n; }; + auto sqr = [](double x) { return x * x; }; + + for (int i = 0; i < MaxDebugSlots; ++i) + if ((n = hit[i][0])) + std::cerr << "Hit #" << i << ": Total " << n << " Hits " << hit[i][1] + << " Hit Rate (%) " << 100.0 * E(hit[i][1]) << std::endl; + + for (int i = 0; i < MaxDebugSlots; ++i) + if ((n = mean[i][0])) { + std::cerr << "Mean #" << i << ": Total " << n << " Mean " << E(mean[i][1]) + << std::endl; + } + + for (int i = 0; i < MaxDebugSlots; ++i) + if ((n = stdev[i][0])) { + double r = sqrt(E(stdev[i][2]) - sqr(E(stdev[i][1]))); + std::cerr << "Stdev #" << i << ": Total " << n << " Stdev " << r << std::endl; + } + + for (int i = 0; i < MaxDebugSlots; ++i) + if ((n = correl[i][0])) { + double r = (E(correl[i][5]) - E(correl[i][1]) * E(correl[i][3])) / + (sqrt(E(correl[i][2]) - sqr(E(correl[i][1]))) * + sqrt(E(correl[i][4]) - sqr(E(correl[i][3])))); + std::cerr << "Correl. #" << i << ": Total " << n << " Coefficient " << r + << std::endl; + } + } - static std::mutex m; - if (sc == IO_LOCK) - m.lock(); + /// Used to serialize access to std::cout to avoid multiple threads writing at + /// the same time. - if (sc == IO_UNLOCK) - m.unlock(); + std::ostream& operator<<(std::ostream& os, SyncCout sc) { - return os; -} + static std::mutex m; + if (sc == IO_LOCK) m.lock(); -/// Trampoline helper to avoid moving Logger to misc.h -void start_logger(const std::string& fname) { Logger::start(fname); } + if (sc == IO_UNLOCK) m.unlock(); + + return os; + } + + + /// Trampoline helper to avoid moving Logger to misc.h + void start_logger(const std::string& fname) { Logger::start(fname); } /// prefetch() preloads the given address in L1/L2 cache. This is a non-blocking @@ -418,364 +401,341 @@ void start_logger(const std::string& fname) { Logger::start(fname); } /// which can be quite slow. #ifdef NO_PREFETCH -void prefetch(void*) {} + void prefetch(void*) {} #else -void prefetch(void* addr) { + void prefetch(void* addr) { -# if defined(_MSC_VER) - _mm_prefetch((char*)addr, _MM_HINT_T0); -# else - __builtin_prefetch(addr); -# endif -} + #if defined(_MSC_VER) + _mm_prefetch((char*) addr, _MM_HINT_T0); + #else + __builtin_prefetch(addr); + #endif + } #endif -/// std_aligned_alloc() is our wrapper for systems where the c++17 implementation -/// does not guarantee the availability of aligned_alloc(). Memory allocated with -/// std_aligned_alloc() must be freed with std_aligned_free(). + /// std_aligned_alloc() is our wrapper for systems where the c++17 implementation + /// does not guarantee the availability of aligned_alloc(). Memory allocated with + /// std_aligned_alloc() must be freed with std_aligned_free(). -void* std_aligned_alloc(size_t alignment, size_t size) { + void* std_aligned_alloc(size_t alignment, size_t size) { #if defined(POSIXALIGNEDALLOC) - void *mem; - return posix_memalign(&mem, alignment, size) ? nullptr : mem; + void* mem; + return posix_memalign(&mem, alignment, size) ? nullptr : mem; #elif defined(_WIN32) && !defined(_M_ARM) && !defined(_M_ARM64) - return _mm_malloc(size, alignment); + return _mm_malloc(size, alignment); #elif defined(_WIN32) - return _aligned_malloc(size, alignment); + return _aligned_malloc(size, alignment); #else - return std::aligned_alloc(alignment, size); + return std::aligned_alloc(alignment, size); #endif -} + } -void std_aligned_free(void* ptr) { + void std_aligned_free(void* ptr) { #if defined(POSIXALIGNEDALLOC) - free(ptr); + free(ptr); #elif defined(_WIN32) && !defined(_M_ARM) && !defined(_M_ARM64) - _mm_free(ptr); + _mm_free(ptr); #elif defined(_WIN32) - _aligned_free(ptr); + _aligned_free(ptr); #else - free(ptr); + free(ptr); #endif -} + } -/// aligned_large_pages_alloc() will return suitably aligned memory, if possible using large pages. + /// aligned_large_pages_alloc() will return suitably aligned memory, if possible using large pages. #if defined(_WIN32) -static void* aligned_large_pages_alloc_windows([[maybe_unused]] size_t allocSize) { - - #if !defined(_WIN64) - return nullptr; - #else - - HANDLE hProcessToken { }; - LUID luid { }; - void* mem = nullptr; - - const size_t largePageSize = GetLargePageMinimum(); - if (!largePageSize) - return nullptr; - - // Dynamically link OpenProcessToken, LookupPrivilegeValue and AdjustTokenPrivileges - - HMODULE hAdvapi32 = GetModuleHandle(TEXT("advapi32.dll")); - - if (!hAdvapi32) - hAdvapi32 = LoadLibrary(TEXT("advapi32.dll")); - - auto fun6 = (fun6_t)(void(*)())GetProcAddress(hAdvapi32, "OpenProcessToken"); - if (!fun6) - return nullptr; - auto fun7 = (fun7_t)(void(*)())GetProcAddress(hAdvapi32, "LookupPrivilegeValueA"); - if (!fun7) - return nullptr; - auto fun8 = (fun8_t)(void(*)())GetProcAddress(hAdvapi32, "AdjustTokenPrivileges"); - if (!fun8) - return nullptr; - - // We need SeLockMemoryPrivilege, so try to enable it for the process - if (!fun6( // OpenProcessToken() - GetCurrentProcess(), TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, &hProcessToken)) - return nullptr; - - if (fun7( // LookupPrivilegeValue(nullptr, SE_LOCK_MEMORY_NAME, &luid) - nullptr, "SeLockMemoryPrivilege", &luid)) - { - TOKEN_PRIVILEGES tp { }; - TOKEN_PRIVILEGES prevTp { }; - DWORD prevTpLen = 0; - - tp.PrivilegeCount = 1; - tp.Privileges[0].Luid = luid; - tp.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED; - - // Try to enable SeLockMemoryPrivilege. Note that even if AdjustTokenPrivileges() succeeds, - // we still need to query GetLastError() to ensure that the privileges were actually obtained. - if (fun8( // AdjustTokenPrivileges() - hProcessToken, FALSE, &tp, sizeof(TOKEN_PRIVILEGES), &prevTp, &prevTpLen) && - GetLastError() == ERROR_SUCCESS) - { - // Round up size to full pages and allocate - allocSize = (allocSize + largePageSize - 1) & ~size_t(largePageSize - 1); - mem = VirtualAlloc( - nullptr, allocSize, MEM_RESERVE | MEM_COMMIT | MEM_LARGE_PAGES, PAGE_READWRITE); - - // Privilege no longer needed, restore previous state - fun8( // AdjustTokenPrivileges () - hProcessToken, FALSE, &prevTp, 0, nullptr, nullptr); - } - } - - CloseHandle(hProcessToken); - - return mem; - - #endif -} + static void* aligned_large_pages_alloc_windows([[maybe_unused]] size_t allocSize) { + + #if !defined(_WIN64) + return nullptr; + #else + + HANDLE hProcessToken{}; + LUID luid{}; + void* mem = nullptr; + + const size_t largePageSize = GetLargePageMinimum(); + if (!largePageSize) return nullptr; + + // Dynamically link OpenProcessToken, LookupPrivilegeValue and AdjustTokenPrivileges + + HMODULE hAdvapi32 = GetModuleHandle(TEXT("advapi32.dll")); + + if (!hAdvapi32) hAdvapi32 = LoadLibrary(TEXT("advapi32.dll")); + + auto fun6 = (fun6_t) (void (*)()) GetProcAddress(hAdvapi32, "OpenProcessToken"); + if (!fun6) return nullptr; + auto fun7 = (fun7_t) (void (*)()) GetProcAddress(hAdvapi32, "LookupPrivilegeValueA"); + if (!fun7) return nullptr; + auto fun8 = (fun8_t) (void (*)()) GetProcAddress(hAdvapi32, "AdjustTokenPrivileges"); + if (!fun8) return nullptr; + + // We need SeLockMemoryPrivilege, so try to enable it for the process + if (!fun6( // OpenProcessToken() + GetCurrentProcess(), TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, &hProcessToken)) + return nullptr; + + if (fun7( // LookupPrivilegeValue(nullptr, SE_LOCK_MEMORY_NAME, &luid) + nullptr, "SeLockMemoryPrivilege", &luid)) { + TOKEN_PRIVILEGES tp{}; + TOKEN_PRIVILEGES prevTp{}; + DWORD prevTpLen = 0; + + tp.PrivilegeCount = 1; + tp.Privileges[0].Luid = luid; + tp.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED; + + // Try to enable SeLockMemoryPrivilege. Note that even if AdjustTokenPrivileges() succeeds, + // we still need to query GetLastError() to ensure that the privileges were actually obtained. + if (fun8( // AdjustTokenPrivileges() + hProcessToken, FALSE, &tp, sizeof(TOKEN_PRIVILEGES), &prevTp, &prevTpLen) && + GetLastError() == ERROR_SUCCESS) { + // Round up size to full pages and allocate + allocSize = (allocSize + largePageSize - 1) & ~size_t(largePageSize - 1); + mem = VirtualAlloc(nullptr, allocSize, MEM_RESERVE | MEM_COMMIT | MEM_LARGE_PAGES, + PAGE_READWRITE); + + // Privilege no longer needed, restore previous state + fun8( // AdjustTokenPrivileges () + hProcessToken, FALSE, &prevTp, 0, nullptr, nullptr); + } + } -void* aligned_large_pages_alloc(size_t allocSize) { + CloseHandle(hProcessToken); - // Try to allocate large pages - void* mem = aligned_large_pages_alloc_windows(allocSize); + return mem; - // Fall back to regular, page aligned, allocation if necessary - if (!mem) - mem = VirtualAlloc(nullptr, allocSize, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); + #endif + } - return mem; -} + void* aligned_large_pages_alloc(size_t allocSize) { -#else + // Try to allocate large pages + void* mem = aligned_large_pages_alloc_windows(allocSize); -void* aligned_large_pages_alloc(size_t allocSize) { + // Fall back to regular, page aligned, allocation if necessary + if (!mem) mem = VirtualAlloc(nullptr, allocSize, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); + + return mem; + } -#if defined(__linux__) - constexpr size_t alignment = 2 * 1024 * 1024; // assumed 2MB page size #else - constexpr size_t alignment = 4096; // assumed small page size -#endif - // round up to multiples of alignment - size_t size = ((allocSize + alignment - 1) / alignment) * alignment; - void *mem = std_aligned_alloc(alignment, size); -#if defined(MADV_HUGEPAGE) - madvise(mem, size, MADV_HUGEPAGE); -#endif - return mem; -} + void* aligned_large_pages_alloc(size_t allocSize) { + + #if defined(__linux__) + constexpr size_t alignment = 2 * 1024 * 1024; // assumed 2MB page size + #else + constexpr size_t alignment = 4096; // assumed small page size + #endif + + // round up to multiples of alignment + size_t size = ((allocSize + alignment - 1) / alignment) * alignment; + void* mem = std_aligned_alloc(alignment, size); + #if defined(MADV_HUGEPAGE) + madvise(mem, size, MADV_HUGEPAGE); + #endif + return mem; + } #endif -/// aligned_large_pages_free() will free the previously allocated ttmem + /// aligned_large_pages_free() will free the previously allocated ttmem #if defined(_WIN32) -void aligned_large_pages_free(void* mem) { + void aligned_large_pages_free(void* mem) { - if (mem && !VirtualFree(mem, 0, MEM_RELEASE)) - { - DWORD err = GetLastError(); - std::cerr << "Failed to free large page memory. Error code: 0x" - << std::hex << err - << std::dec << std::endl; - exit(EXIT_FAILURE); - } -} + if (mem && !VirtualFree(mem, 0, MEM_RELEASE)) { + DWORD err = GetLastError(); + std::cerr << "Failed to free large page memory. Error code: 0x" << std::hex << err + << std::dec << std::endl; + exit(EXIT_FAILURE); + } + } #else -void aligned_large_pages_free(void *mem) { - std_aligned_free(mem); -} + void aligned_large_pages_free(void* mem) { std_aligned_free(mem); } #endif -namespace WinProcGroup { + namespace WinProcGroup { #ifndef _WIN32 -void bindThisThread(size_t) {} + void bindThisThread(size_t) {} #else -/// best_node() retrieves logical processor information using Windows specific -/// API and returns the best node id for the thread with index idx. Original -/// code from Texel by Peter Österlund. - -static int best_node(size_t idx) { - - int threads = 0; - int nodes = 0; - int cores = 0; - DWORD returnLength = 0; - DWORD byteOffset = 0; - - // Early exit if the needed API is not available at runtime - HMODULE k32 = GetModuleHandle(TEXT("Kernel32.dll")); - auto fun1 = (fun1_t)(void(*)())GetProcAddress(k32, "GetLogicalProcessorInformationEx"); - if (!fun1) - return -1; - - // First call to GetLogicalProcessorInformationEx() to get returnLength. - // We expect the call to fail due to null buffer. - if (fun1(RelationAll, nullptr, &returnLength)) - return -1; - - // Once we know returnLength, allocate the buffer - SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX *buffer, *ptr; - ptr = buffer = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*)malloc(returnLength); - - // Second call to GetLogicalProcessorInformationEx(), now we expect to succeed - if (!fun1(RelationAll, buffer, &returnLength)) - { - free(buffer); - return -1; - } - - while (byteOffset < returnLength) - { - if (ptr->Relationship == RelationNumaNode) - nodes++; - - else if (ptr->Relationship == RelationProcessorCore) - { - cores++; - threads += (ptr->Processor.Flags == LTP_PC_SMT) ? 2 : 1; - } - - assert(ptr->Size); - byteOffset += ptr->Size; - ptr = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*)(((char*)ptr) + ptr->Size); - } - - free(buffer); - - std::vector groups; - - // Run as many threads as possible on the same node until core limit is - // reached, then move on filling the next node. - for (int n = 0; n < nodes; n++) - for (int i = 0; i < cores / nodes; i++) - groups.push_back(n); - - // In case a core has more than one logical processor (we assume 2) and we - // have still threads to allocate, then spread them evenly across available - // nodes. - for (int t = 0; t < threads - cores; t++) - groups.push_back(t % nodes); - - // If we still have more threads than the total number of logical processors - // then return -1 and let the OS to decide what to do. - return idx < groups.size() ? groups[idx] : -1; -} + /// best_node() retrieves logical processor information using Windows specific + /// API and returns the best node id for the thread with index idx. Original + /// code from Texel by Peter Österlund. + + static int best_node(size_t idx) { + + int threads = 0; + int nodes = 0; + int cores = 0; + DWORD returnLength = 0; + DWORD byteOffset = 0; + + // Early exit if the needed API is not available at runtime + HMODULE k32 = GetModuleHandle(TEXT("Kernel32.dll")); + auto fun1 = + (fun1_t) (void (*)()) GetProcAddress(k32, "GetLogicalProcessorInformationEx"); + if (!fun1) return -1; + + // First call to GetLogicalProcessorInformationEx() to get returnLength. + // We expect the call to fail due to null buffer. + if (fun1(RelationAll, nullptr, &returnLength)) return -1; + + // Once we know returnLength, allocate the buffer + SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX *buffer, *ptr; + ptr = buffer = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*) malloc(returnLength); + + // Second call to GetLogicalProcessorInformationEx(), now we expect to succeed + if (!fun1(RelationAll, buffer, &returnLength)) { + free(buffer); + return -1; + } + + while (byteOffset < returnLength) { + if (ptr->Relationship == RelationNumaNode) + nodes++; + + else if (ptr->Relationship == RelationProcessorCore) { + cores++; + threads += (ptr->Processor.Flags == LTP_PC_SMT) ? 2 : 1; + } + + assert(ptr->Size); + byteOffset += ptr->Size; + ptr = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*) (((char*) ptr) + ptr->Size); + } + + free(buffer); + + std::vector groups; + + // Run as many threads as possible on the same node until core limit is + // reached, then move on filling the next node. + for (int n = 0; n < nodes; n++) + for (int i = 0; i < cores / nodes; i++) groups.push_back(n); + + // In case a core has more than one logical processor (we assume 2) and we + // have still threads to allocate, then spread them evenly across available + // nodes. + for (int t = 0; t < threads - cores; t++) groups.push_back(t % nodes); + + // If we still have more threads than the total number of logical processors + // then return -1 and let the OS to decide what to do. + return idx < groups.size() ? groups[idx] : -1; + } -/// bindThisThread() set the group affinity of the current thread - -void bindThisThread(size_t idx) { - - // Use only local variables to be thread-safe - int node = best_node(idx); - - if (node == -1) - return; - - // Early exit if the needed API are not available at runtime - HMODULE k32 = GetModuleHandle(TEXT("Kernel32.dll")); - auto fun2 = (fun2_t)(void(*)())GetProcAddress(k32, "GetNumaNodeProcessorMaskEx"); - auto fun3 = (fun3_t)(void(*)())GetProcAddress(k32, "SetThreadGroupAffinity"); - auto fun4 = (fun4_t)(void(*)())GetProcAddress(k32, "GetNumaNodeProcessorMask2"); - auto fun5 = (fun5_t)(void(*)())GetProcAddress(k32, "GetMaximumProcessorGroupCount"); - - if (!fun2 || !fun3) - return; - - if (!fun4 || !fun5) - { - GROUP_AFFINITY affinity; - if (fun2(node, &affinity)) // GetNumaNodeProcessorMaskEx - fun3(GetCurrentThread(), &affinity, nullptr); // SetThreadGroupAffinity - } - else - { - // If a numa node has more than one processor group, we assume they are - // sized equal and we spread threads evenly across the groups. - USHORT elements, returnedElements; - elements = fun5(); // GetMaximumProcessorGroupCount - GROUP_AFFINITY *affinity = (GROUP_AFFINITY*)malloc(elements * sizeof(GROUP_AFFINITY)); - if (fun4(node, affinity, elements, &returnedElements)) // GetNumaNodeProcessorMask2 - fun3(GetCurrentThread(), &affinity[idx % returnedElements], nullptr); // SetThreadGroupAffinity - free(affinity); - } -} + /// bindThisThread() set the group affinity of the current thread + + void bindThisThread(size_t idx) { + + // Use only local variables to be thread-safe + int node = best_node(idx); + + if (node == -1) return; + + // Early exit if the needed API are not available at runtime + HMODULE k32 = GetModuleHandle(TEXT("Kernel32.dll")); + auto fun2 = (fun2_t) (void (*)()) GetProcAddress(k32, "GetNumaNodeProcessorMaskEx"); + auto fun3 = (fun3_t) (void (*)()) GetProcAddress(k32, "SetThreadGroupAffinity"); + auto fun4 = (fun4_t) (void (*)()) GetProcAddress(k32, "GetNumaNodeProcessorMask2"); + auto fun5 = (fun5_t) (void (*)()) GetProcAddress(k32, "GetMaximumProcessorGroupCount"); + + if (!fun2 || !fun3) return; + + if (!fun4 || !fun5) { + GROUP_AFFINITY affinity; + if (fun2(node, &affinity)) // GetNumaNodeProcessorMaskEx + fun3(GetCurrentThread(), &affinity, nullptr); // SetThreadGroupAffinity + } else { + // If a numa node has more than one processor group, we assume they are + // sized equal and we spread threads evenly across the groups. + USHORT elements, returnedElements; + elements = fun5(); // GetMaximumProcessorGroupCount + GROUP_AFFINITY* affinity = + (GROUP_AFFINITY*) malloc(elements * sizeof(GROUP_AFFINITY)); + if (fun4(node, affinity, elements, &returnedElements)) // GetNumaNodeProcessorMask2 + fun3(GetCurrentThread(), &affinity[idx % returnedElements], + nullptr); // SetThreadGroupAffinity + free(affinity); + } + } #endif -} // namespace WinProcGroup + } // namespace WinProcGroup #ifdef _WIN32 -#include -#define GETCWD _getcwd + #include + #define GETCWD _getcwd #else -#include -#define GETCWD getcwd + #include + #define GETCWD getcwd #endif -namespace CommandLine { + namespace CommandLine { -std::string argv0; // path+name of the executable binary, as given by argv[0] -std::string binaryDirectory; // path of the executable directory -std::string workingDirectory; // path of the working directory + std::string argv0; // path+name of the executable binary, as given by argv[0] + std::string binaryDirectory; // path of the executable directory + std::string workingDirectory; // path of the working directory -void init([[maybe_unused]] int argc, char* argv[]) { - std::string pathSeparator; + void init([[maybe_unused]] int argc, char* argv[]) { + std::string pathSeparator; - // extract the path+name of the executable binary - argv0 = argv[0]; + // extract the path+name of the executable binary + argv0 = argv[0]; #ifdef _WIN32 - pathSeparator = "\\"; - #ifdef _MSC_VER - // Under windows argv[0] may not have the extension. Also _get_pgmptr() had - // issues in some windows 10 versions, so check returned values carefully. - char* pgmptr = nullptr; - if (!_get_pgmptr(&pgmptr) && pgmptr != nullptr && *pgmptr) - argv0 = pgmptr; - #endif + pathSeparator = "\\"; + #ifdef _MSC_VER + // Under windows argv[0] may not have the extension. Also _get_pgmptr() had + // issues in some windows 10 versions, so check returned values carefully. + char* pgmptr = nullptr; + if (!_get_pgmptr(&pgmptr) && pgmptr != nullptr && *pgmptr) argv0 = pgmptr; + #endif #else - pathSeparator = "/"; + pathSeparator = "/"; #endif - // extract the working directory - workingDirectory = ""; - char buff[40000]; - char* cwd = GETCWD(buff, 40000); - if (cwd) - workingDirectory = cwd; - - // extract the binary directory path from argv0 - binaryDirectory = argv0; - size_t pos = binaryDirectory.find_last_of("\\/"); - if (pos == std::string::npos) - binaryDirectory = "." + pathSeparator; - else - binaryDirectory.resize(pos + 1); - - // pattern replacement: "./" at the start of path is replaced by the working directory - if (binaryDirectory.find("." + pathSeparator) == 0) - binaryDirectory.replace(0, 1, workingDirectory); -} + // extract the working directory + workingDirectory = ""; + char buff[40000]; + char* cwd = GETCWD(buff, 40000); + if (cwd) workingDirectory = cwd; + + // extract the binary directory path from argv0 + binaryDirectory = argv0; + size_t pos = binaryDirectory.find_last_of("\\/"); + if (pos == std::string::npos) + binaryDirectory = "." + pathSeparator; + else + binaryDirectory.resize(pos + 1); + + // pattern replacement: "./" at the start of path is replaced by the working directory + if (binaryDirectory.find("." + pathSeparator) == 0) + binaryDirectory.replace(0, 1, workingDirectory); + } -} // namespace CommandLine + } // namespace CommandLine -} // namespace Stockfish +} // namespace Stockfish diff --git a/src/misc.h b/src/misc.h index aed677b5b29..6ef9bd00707 100644 --- a/src/misc.h +++ b/src/misc.h @@ -31,136 +31,141 @@ namespace Stockfish { -std::string engine_info(bool to_uci = false); -std::string compiler_info(); -void prefetch(void* addr); -void start_logger(const std::string& fname); -void* std_aligned_alloc(size_t alignment, size_t size); -void std_aligned_free(void* ptr); -void* aligned_large_pages_alloc(size_t size); // memory aligned by page size, min alignment: 4096 bytes -void aligned_large_pages_free(void* mem); // nop if mem == nullptr - -void dbg_hit_on(bool cond, int slot = 0); -void dbg_mean_of(int64_t value, int slot = 0); -void dbg_stdev_of(int64_t value, int slot = 0); -void dbg_correl_of(int64_t value1, int64_t value2, int slot = 0); -void dbg_print(); - -using TimePoint = std::chrono::milliseconds::rep; // A value in milliseconds -static_assert(sizeof(TimePoint) == sizeof(int64_t), "TimePoint should be 64 bits"); -inline TimePoint now() { - return std::chrono::duration_cast - (std::chrono::steady_clock::now().time_since_epoch()).count(); -} - - -enum SyncCout { IO_LOCK, IO_UNLOCK }; -std::ostream& operator<<(std::ostream&, SyncCout); + std::string engine_info(bool to_uci = false); + std::string compiler_info(); + void prefetch(void* addr); + void start_logger(const std::string& fname); + void* std_aligned_alloc(size_t alignment, size_t size); + void std_aligned_free(void* ptr); + void* aligned_large_pages_alloc( + size_t size); // memory aligned by page size, min alignment: 4096 bytes + void aligned_large_pages_free(void* mem); // nop if mem == nullptr + + void dbg_hit_on(bool cond, int slot = 0); + void dbg_mean_of(int64_t value, int slot = 0); + void dbg_stdev_of(int64_t value, int slot = 0); + void dbg_correl_of(int64_t value1, int64_t value2, int slot = 0); + void dbg_print(); + + using TimePoint = std::chrono::milliseconds::rep; // A value in milliseconds + static_assert(sizeof(TimePoint) == sizeof(int64_t), "TimePoint should be 64 bits"); + inline TimePoint now() { + return std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + } + + + enum SyncCout { + IO_LOCK, + IO_UNLOCK + }; + std::ostream& operator<<(std::ostream&, SyncCout); #define sync_cout std::cout << IO_LOCK #define sync_endl std::endl << IO_UNLOCK -// align_ptr_up() : get the first aligned element of an array. -// ptr must point to an array of size at least `sizeof(T) * N + alignment` bytes, -// where N is the number of elements in the array. -template -T* align_ptr_up(T* ptr) -{ - static_assert(alignof(T) < Alignment); + // align_ptr_up() : get the first aligned element of an array. + // ptr must point to an array of size at least `sizeof(T) * N + alignment` bytes, + // where N is the number of elements in the array. + template T* align_ptr_up(T* ptr) { + static_assert(alignof(T) < Alignment); - const uintptr_t ptrint = reinterpret_cast(reinterpret_cast(ptr)); - return reinterpret_cast(reinterpret_cast((ptrint + (Alignment - 1)) / Alignment * Alignment)); -} + const uintptr_t ptrint = reinterpret_cast(reinterpret_cast(ptr)); + return reinterpret_cast( + reinterpret_cast((ptrint + (Alignment - 1)) / Alignment * Alignment)); + } -// IsLittleEndian : true if and only if the binary is compiled on a little endian machine -static inline const union { uint32_t i; char c[4]; } Le = { 0x01020304 }; -static inline const bool IsLittleEndian = (Le.c[0] == 4); + // IsLittleEndian : true if and only if the binary is compiled on a little endian machine + static inline const union { + uint32_t i; + char c[4]; + } Le = {0x01020304}; + static inline const bool IsLittleEndian = (Le.c[0] == 4); -template -class ValueList { + template class ValueList { -public: - std::size_t size() const { return size_; } - void push_back(const T& value) { values_[size_++] = value; } - const T* begin() const { return values_; } - const T* end() const { return values_ + size_; } + public: + std::size_t size() const { return size_; } + void push_back(const T& value) { values_[size_++] = value; } + const T* begin() const { return values_; } + const T* end() const { return values_ + size_; } -private: - T values_[MaxSize]; - std::size_t size_ = 0; -}; + private: + T values_[MaxSize]; + std::size_t size_ = 0; + }; -/// xorshift64star Pseudo-Random Number Generator -/// This class is based on original code written and dedicated -/// to the public domain by Sebastiano Vigna (2014). -/// It has the following characteristics: -/// -/// - Outputs 64-bit numbers -/// - Passes Dieharder and SmallCrush test batteries -/// - Does not require warm-up, no zeroland to escape -/// - Internal state is a single 64-bit integer -/// - Period is 2^64 - 1 -/// - Speed: 1.60 ns/call (Core i7 @3.40GHz) -/// -/// For further analysis see -/// + /// xorshift64star Pseudo-Random Number Generator + /// This class is based on original code written and dedicated + /// to the public domain by Sebastiano Vigna (2014). + /// It has the following characteristics: + /// + /// - Outputs 64-bit numbers + /// - Passes Dieharder and SmallCrush test batteries + /// - Does not require warm-up, no zeroland to escape + /// - Internal state is a single 64-bit integer + /// - Period is 2^64 - 1 + /// - Speed: 1.60 ns/call (Core i7 @3.40GHz) + /// + /// For further analysis see + /// -class PRNG { + class PRNG { - uint64_t s; + uint64_t s; - uint64_t rand64() { + uint64_t rand64() { - s ^= s >> 12, s ^= s << 25, s ^= s >> 27; - return s * 2685821657736338717LL; - } + s ^= s >> 12, s ^= s << 25, s ^= s >> 27; + return s * 2685821657736338717LL; + } -public: - PRNG(uint64_t seed) : s(seed) { assert(seed); } + public: + PRNG(uint64_t seed) : s(seed) { assert(seed); } - template T rand() { return T(rand64()); } + template T rand() { return T(rand64()); } - /// Special generator used to fast init magic numbers. - /// Output values only have 1/8th of their bits set on average. - template T sparse_rand() - { return T(rand64() & rand64() & rand64()); } -}; + /// Special generator used to fast init magic numbers. + /// Output values only have 1/8th of their bits set on average. + template T sparse_rand() { return T(rand64() & rand64() & rand64()); } + }; -inline uint64_t mul_hi64(uint64_t a, uint64_t b) { + inline uint64_t mul_hi64(uint64_t a, uint64_t b) { #if defined(__GNUC__) && defined(IS_64BIT) - __extension__ using uint128 = unsigned __int128; - return ((uint128)a * (uint128)b) >> 64; + __extension__ using uint128 = unsigned __int128; + return ((uint128) a * (uint128) b) >> 64; #else - uint64_t aL = (uint32_t)a, aH = a >> 32; - uint64_t bL = (uint32_t)b, bH = b >> 32; - uint64_t c1 = (aL * bL) >> 32; - uint64_t c2 = aH * bL + c1; - uint64_t c3 = aL * bH + (uint32_t)c2; - return aH * bH + (c2 >> 32) + (c3 >> 32); + uint64_t aL = (uint32_t) a, aH = a >> 32; + uint64_t bL = (uint32_t) b, bH = b >> 32; + uint64_t c1 = (aL * bL) >> 32; + uint64_t c2 = aH * bL + c1; + uint64_t c3 = aL * bH + (uint32_t) c2; + return aH * bH + (c2 >> 32) + (c3 >> 32); #endif -} + } -/// Under Windows it is not possible for a process to run on more than one -/// logical processor group. This usually means to be limited to use max 64 -/// cores. To overcome this, some special platform specific API should be -/// called to set group affinity for each thread. Original code from Texel by -/// Peter Österlund. + /// Under Windows it is not possible for a process to run on more than one + /// logical processor group. This usually means to be limited to use max 64 + /// cores. To overcome this, some special platform specific API should be + /// called to set group affinity for each thread. Original code from Texel by + /// Peter Österlund. -namespace WinProcGroup { - void bindThisThread(size_t idx); -} + namespace WinProcGroup { + void bindThisThread(size_t idx); + } -namespace CommandLine { - void init(int argc, char* argv[]); + namespace CommandLine { + void init(int argc, char* argv[]); - extern std::string binaryDirectory; // path of the executable directory - extern std::string workingDirectory; // path of the working directory -} + extern std::string binaryDirectory; // path of the executable directory + extern std::string workingDirectory; // path of the working directory + } -} // namespace Stockfish +} // namespace Stockfish -#endif // #ifndef MISC_H_INCLUDED +#endif // #ifndef MISC_H_INCLUDED diff --git a/src/movegen.cpp b/src/movegen.cpp index f0733c73b66..a06e971921d 100644 --- a/src/movegen.cpp +++ b/src/movegen.cpp @@ -26,262 +26,235 @@ namespace Stockfish { -namespace { - - template - ExtMove* make_promotions(ExtMove* moveList, [[maybe_unused]] Square to) { - - if constexpr (Type == CAPTURES || Type == EVASIONS || Type == NON_EVASIONS) - { - *moveList++ = make(to - D, to, QUEEN); - if constexpr (Enemy && Type == CAPTURES) - { - *moveList++ = make(to - D, to, ROOK); - *moveList++ = make(to - D, to, BISHOP); - *moveList++ = make(to - D, to, KNIGHT); + namespace { + + template + ExtMove* make_promotions(ExtMove* moveList, [[maybe_unused]] Square to) { + + if constexpr (Type == CAPTURES || Type == EVASIONS || Type == NON_EVASIONS) { + *moveList++ = make(to - D, to, QUEEN); + if constexpr (Enemy && Type == CAPTURES) { + *moveList++ = make(to - D, to, ROOK); + *moveList++ = make(to - D, to, BISHOP); + *moveList++ = make(to - D, to, KNIGHT); + } + } + + if constexpr ((Type == QUIETS && !Enemy) || Type == EVASIONS || Type == NON_EVASIONS) { + *moveList++ = make(to - D, to, ROOK); + *moveList++ = make(to - D, to, BISHOP); + *moveList++ = make(to - D, to, KNIGHT); + } + + return moveList; } - } - if constexpr ((Type == QUIETS && !Enemy) || Type == EVASIONS || Type == NON_EVASIONS) - { - *moveList++ = make(to - D, to, ROOK); - *moveList++ = make(to - D, to, BISHOP); - *moveList++ = make(to - D, to, KNIGHT); - } - return moveList; - } + template + ExtMove* generate_pawn_moves(const Position& pos, ExtMove* moveList, Bitboard target) { + constexpr Color Them = ~Us; + constexpr Bitboard TRank7BB = (Us == WHITE ? Rank7BB : Rank2BB); + constexpr Bitboard TRank3BB = (Us == WHITE ? Rank3BB : Rank6BB); + constexpr Direction Up = pawn_push(Us); + constexpr Direction UpRight = (Us == WHITE ? NORTH_EAST : SOUTH_WEST); + constexpr Direction UpLeft = (Us == WHITE ? NORTH_WEST : SOUTH_EAST); - template - ExtMove* generate_pawn_moves(const Position& pos, ExtMove* moveList, Bitboard target) { + const Bitboard emptySquares = ~pos.pieces(); + const Bitboard enemies = Type == EVASIONS ? pos.checkers() : pos.pieces(Them); - constexpr Color Them = ~Us; - constexpr Bitboard TRank7BB = (Us == WHITE ? Rank7BB : Rank2BB); - constexpr Bitboard TRank3BB = (Us == WHITE ? Rank3BB : Rank6BB); - constexpr Direction Up = pawn_push(Us); - constexpr Direction UpRight = (Us == WHITE ? NORTH_EAST : SOUTH_WEST); - constexpr Direction UpLeft = (Us == WHITE ? NORTH_WEST : SOUTH_EAST); + Bitboard pawnsOn7 = pos.pieces(Us, PAWN) & TRank7BB; + Bitboard pawnsNotOn7 = pos.pieces(Us, PAWN) & ~TRank7BB; - const Bitboard emptySquares = ~pos.pieces(); - const Bitboard enemies = Type == EVASIONS ? pos.checkers() - : pos.pieces(Them); + // Single and double pawn pushes, no promotions + if constexpr (Type != CAPTURES) { + Bitboard b1 = shift(pawnsNotOn7) & emptySquares; + Bitboard b2 = shift(b1 & TRank3BB) & emptySquares; - Bitboard pawnsOn7 = pos.pieces(Us, PAWN) & TRank7BB; - Bitboard pawnsNotOn7 = pos.pieces(Us, PAWN) & ~TRank7BB; + if constexpr (Type == EVASIONS) // Consider only blocking squares + { + b1 &= target; + b2 &= target; + } - // Single and double pawn pushes, no promotions - if constexpr (Type != CAPTURES) - { - Bitboard b1 = shift(pawnsNotOn7) & emptySquares; - Bitboard b2 = shift(b1 & TRank3BB) & emptySquares; + if constexpr (Type == QUIET_CHECKS) { + // To make a quiet check, you either make a direct check by pushing a pawn + // or push a blocker pawn that is not on the same file as the enemy king. + // Discovered check promotion has been already generated amongst the captures. + Square ksq = pos.square(Them); + Bitboard dcCandidatePawns = pos.blockers_for_king(Them) & ~file_bb(ksq); + b1 &= pawn_attacks_bb(Them, ksq) | shift(dcCandidatePawns); + b2 &= pawn_attacks_bb(Them, ksq) | shift(dcCandidatePawns); + } - if constexpr (Type == EVASIONS) // Consider only blocking squares - { - b1 &= target; - b2 &= target; - } - - if constexpr (Type == QUIET_CHECKS) - { - // To make a quiet check, you either make a direct check by pushing a pawn - // or push a blocker pawn that is not on the same file as the enemy king. - // Discovered check promotion has been already generated amongst the captures. - Square ksq = pos.square(Them); - Bitboard dcCandidatePawns = pos.blockers_for_king(Them) & ~file_bb(ksq); - b1 &= pawn_attacks_bb(Them, ksq) | shift< Up>(dcCandidatePawns); - b2 &= pawn_attacks_bb(Them, ksq) | shift(dcCandidatePawns); - } + while (b1) { + Square to = pop_lsb(b1); + *moveList++ = make_move(to - Up, to); + } - while (b1) - { - Square to = pop_lsb(b1); - *moveList++ = make_move(to - Up, to); - } + while (b2) { + Square to = pop_lsb(b2); + *moveList++ = make_move(to - Up - Up, to); + } + } - while (b2) - { - Square to = pop_lsb(b2); - *moveList++ = make_move(to - Up - Up, to); - } - } + // Promotions and underpromotions + if (pawnsOn7) { + Bitboard b1 = shift(pawnsOn7) & enemies; + Bitboard b2 = shift(pawnsOn7) & enemies; + Bitboard b3 = shift(pawnsOn7) & emptySquares; - // Promotions and underpromotions - if (pawnsOn7) - { - Bitboard b1 = shift(pawnsOn7) & enemies; - Bitboard b2 = shift(pawnsOn7) & enemies; - Bitboard b3 = shift(pawnsOn7) & emptySquares; + if constexpr (Type == EVASIONS) b3 &= target; - if constexpr (Type == EVASIONS) - b3 &= target; + while (b1) moveList = make_promotions(moveList, pop_lsb(b1)); - while (b1) - moveList = make_promotions(moveList, pop_lsb(b1)); + while (b2) moveList = make_promotions(moveList, pop_lsb(b2)); - while (b2) - moveList = make_promotions(moveList, pop_lsb(b2)); + while (b3) moveList = make_promotions(moveList, pop_lsb(b3)); + } - while (b3) - moveList = make_promotions(moveList, pop_lsb(b3)); - } + // Standard and en passant captures + if constexpr (Type == CAPTURES || Type == EVASIONS || Type == NON_EVASIONS) { + Bitboard b1 = shift(pawnsNotOn7) & enemies; + Bitboard b2 = shift(pawnsNotOn7) & enemies; - // Standard and en passant captures - if constexpr (Type == CAPTURES || Type == EVASIONS || Type == NON_EVASIONS) - { - Bitboard b1 = shift(pawnsNotOn7) & enemies; - Bitboard b2 = shift(pawnsNotOn7) & enemies; + while (b1) { + Square to = pop_lsb(b1); + *moveList++ = make_move(to - UpRight, to); + } - while (b1) - { - Square to = pop_lsb(b1); - *moveList++ = make_move(to - UpRight, to); - } + while (b2) { + Square to = pop_lsb(b2); + *moveList++ = make_move(to - UpLeft, to); + } - while (b2) - { - Square to = pop_lsb(b2); - *moveList++ = make_move(to - UpLeft, to); - } + if (pos.ep_square() != SQ_NONE) { + assert(rank_of(pos.ep_square()) == relative_rank(Us, RANK_6)); - if (pos.ep_square() != SQ_NONE) - { - assert(rank_of(pos.ep_square()) == relative_rank(Us, RANK_6)); + // An en passant capture cannot resolve a discovered check + if (Type == EVASIONS && (target & (pos.ep_square() + Up))) return moveList; - // An en passant capture cannot resolve a discovered check - if (Type == EVASIONS && (target & (pos.ep_square() + Up))) - return moveList; + b1 = pawnsNotOn7 & pawn_attacks_bb(Them, pos.ep_square()); - b1 = pawnsNotOn7 & pawn_attacks_bb(Them, pos.ep_square()); + assert(b1); - assert(b1); + while (b1) *moveList++ = make(pop_lsb(b1), pos.ep_square()); + } + } - while (b1) - *moveList++ = make(pop_lsb(b1), pos.ep_square()); + return moveList; } - } - return moveList; - } + template + ExtMove* generate_moves(const Position& pos, ExtMove* moveList, Bitboard target) { - template - ExtMove* generate_moves(const Position& pos, ExtMove* moveList, Bitboard target) { + static_assert(Pt != KING && Pt != PAWN, "Unsupported piece type in generate_moves()"); - static_assert(Pt != KING && Pt != PAWN, "Unsupported piece type in generate_moves()"); + Bitboard bb = pos.pieces(Us, Pt); - Bitboard bb = pos.pieces(Us, Pt); + while (bb) { + Square from = pop_lsb(bb); + Bitboard b = attacks_bb(from, pos.pieces()) & target; - while (bb) - { - Square from = pop_lsb(bb); - Bitboard b = attacks_bb(from, pos.pieces()) & target; + // To check, you either move freely a blocker or make a direct check. + if (Checks && (Pt == QUEEN || !(pos.blockers_for_king(~Us) & from))) + b &= pos.check_squares(Pt); - // To check, you either move freely a blocker or make a direct check. - if (Checks && (Pt == QUEEN || !(pos.blockers_for_king(~Us) & from))) - b &= pos.check_squares(Pt); + while (b) *moveList++ = make_move(from, pop_lsb(b)); + } - while (b) - *moveList++ = make_move(from, pop_lsb(b)); - } - - return moveList; - } + return moveList; + } - template - ExtMove* generate_all(const Position& pos, ExtMove* moveList) { + template + ExtMove* generate_all(const Position& pos, ExtMove* moveList) { - static_assert(Type != LEGAL, "Unsupported type in generate_all()"); + static_assert(Type != LEGAL, "Unsupported type in generate_all()"); - constexpr bool Checks = Type == QUIET_CHECKS; // Reduce template instantiations - const Square ksq = pos.square(Us); - Bitboard target; + constexpr bool Checks = Type == QUIET_CHECKS; // Reduce template instantiations + const Square ksq = pos.square(Us); + Bitboard target; - // Skip generating non-king moves when in double check - if (Type != EVASIONS || !more_than_one(pos.checkers())) - { - target = Type == EVASIONS ? between_bb(ksq, lsb(pos.checkers())) - : Type == NON_EVASIONS ? ~pos.pieces( Us) - : Type == CAPTURES ? pos.pieces(~Us) - : ~pos.pieces( ); // QUIETS || QUIET_CHECKS + // Skip generating non-king moves when in double check + if (Type != EVASIONS || !more_than_one(pos.checkers())) { + target = Type == EVASIONS ? between_bb(ksq, lsb(pos.checkers())) : + Type == NON_EVASIONS ? ~pos.pieces(Us) : + Type == CAPTURES ? pos.pieces(~Us) : + ~pos.pieces(); // QUIETS || QUIET_CHECKS - moveList = generate_pawn_moves(pos, moveList, target); - moveList = generate_moves(pos, moveList, target); - moveList = generate_moves(pos, moveList, target); - moveList = generate_moves(pos, moveList, target); - moveList = generate_moves(pos, moveList, target); - } + moveList = generate_pawn_moves(pos, moveList, target); + moveList = generate_moves(pos, moveList, target); + moveList = generate_moves(pos, moveList, target); + moveList = generate_moves(pos, moveList, target); + moveList = generate_moves(pos, moveList, target); + } - if (!Checks || pos.blockers_for_king(~Us) & ksq) - { - Bitboard b = attacks_bb(ksq) & (Type == EVASIONS ? ~pos.pieces(Us) : target); - if (Checks) - b &= ~attacks_bb(pos.square(~Us)); + if (!Checks || pos.blockers_for_king(~Us) & ksq) { + Bitboard b = attacks_bb(ksq) & (Type == EVASIONS ? ~pos.pieces(Us) : target); + if (Checks) b &= ~attacks_bb(pos.square(~Us)); - while (b) - *moveList++ = make_move(ksq, pop_lsb(b)); + while (b) *moveList++ = make_move(ksq, pop_lsb(b)); - if ((Type == QUIETS || Type == NON_EVASIONS) && pos.can_castle(Us & ANY_CASTLING)) - for (CastlingRights cr : { Us & KING_SIDE, Us & QUEEN_SIDE } ) - if (!pos.castling_impeded(cr) && pos.can_castle(cr)) - *moveList++ = make(ksq, pos.castling_rook_square(cr)); - } + if ((Type == QUIETS || Type == NON_EVASIONS) && pos.can_castle(Us & ANY_CASTLING)) + for (CastlingRights cr : {Us & KING_SIDE, Us & QUEEN_SIDE}) + if (!pos.castling_impeded(cr) && pos.can_castle(cr)) + *moveList++ = make(ksq, pos.castling_rook_square(cr)); + } - return moveList; - } + return moveList; + } -} // namespace + } // namespace -/// Generates all pseudo-legal captures plus queen promotions -/// Generates all pseudo-legal non-captures and underpromotions -/// Generates all pseudo-legal check evasions when the side to move is in check -/// Generates all pseudo-legal non-captures giving check, except castling and promotions -/// Generates all pseudo-legal captures and non-captures -/// -/// Returns a pointer to the end of the move list. + /// Generates all pseudo-legal captures plus queen promotions + /// Generates all pseudo-legal non-captures and underpromotions + /// Generates all pseudo-legal check evasions when the side to move is in check + /// Generates all pseudo-legal non-captures giving check, except castling and promotions + /// Generates all pseudo-legal captures and non-captures + /// + /// Returns a pointer to the end of the move list. -template -ExtMove* generate(const Position& pos, ExtMove* moveList) { + template ExtMove* generate(const Position& pos, ExtMove* moveList) { - static_assert(Type != LEGAL, "Unsupported type in generate()"); - assert((Type == EVASIONS) == (bool)pos.checkers()); + static_assert(Type != LEGAL, "Unsupported type in generate()"); + assert((Type == EVASIONS) == (bool) pos.checkers()); - Color us = pos.side_to_move(); + Color us = pos.side_to_move(); - return us == WHITE ? generate_all(pos, moveList) - : generate_all(pos, moveList); -} + return us == WHITE ? generate_all(pos, moveList) : + generate_all(pos, moveList); + } -// Explicit template instantiations -template ExtMove* generate(const Position&, ExtMove*); -template ExtMove* generate(const Position&, ExtMove*); -template ExtMove* generate(const Position&, ExtMove*); -template ExtMove* generate(const Position&, ExtMove*); -template ExtMove* generate(const Position&, ExtMove*); + // Explicit template instantiations + template ExtMove* generate(const Position&, ExtMove*); + template ExtMove* generate(const Position&, ExtMove*); + template ExtMove* generate(const Position&, ExtMove*); + template ExtMove* generate(const Position&, ExtMove*); + template ExtMove* generate(const Position&, ExtMove*); -/// generate generates all the legal moves in the given position + /// generate generates all the legal moves in the given position -template<> -ExtMove* generate(const Position& pos, ExtMove* moveList) { + template<> ExtMove* generate(const Position& pos, ExtMove* moveList) { - Color us = pos.side_to_move(); - Bitboard pinned = pos.blockers_for_king(us) & pos.pieces(us); - Square ksq = pos.square(us); - ExtMove* cur = moveList; + Color us = pos.side_to_move(); + Bitboard pinned = pos.blockers_for_king(us) & pos.pieces(us); + Square ksq = pos.square(us); + ExtMove* cur = moveList; - moveList = pos.checkers() ? generate(pos, moveList) - : generate(pos, moveList); - while (cur != moveList) - if ( ((pinned & from_sq(*cur)) || from_sq(*cur) == ksq || type_of(*cur) == EN_PASSANT) - && !pos.legal(*cur)) - *cur = (--moveList)->move; - else - ++cur; + moveList = pos.checkers() ? generate(pos, moveList) : + generate(pos, moveList); + while (cur != moveList) + if (((pinned & from_sq(*cur)) || from_sq(*cur) == ksq || type_of(*cur) == EN_PASSANT) && + !pos.legal(*cur)) + *cur = (--moveList)->move; + else + ++cur; - return moveList; -} + return moveList; + } -} // namespace Stockfish +} // namespace Stockfish diff --git a/src/movegen.h b/src/movegen.h index 6449de25794..6b949dcb0ad 100644 --- a/src/movegen.h +++ b/src/movegen.h @@ -26,53 +26,47 @@ namespace Stockfish { -class Position; - -enum GenType { - CAPTURES, - QUIETS, - QUIET_CHECKS, - EVASIONS, - NON_EVASIONS, - LEGAL -}; - -struct ExtMove { - Move move; - int value; - - operator Move() const { return move; } - void operator=(Move m) { move = m; } - - // Inhibit unwanted implicit conversions to Move - // with an ambiguity that yields to a compile error. - operator float() const = delete; -}; - -inline bool operator<(const ExtMove& f, const ExtMove& s) { - return f.value < s.value; -} - -template -ExtMove* generate(const Position& pos, ExtMove* moveList); - -/// The MoveList struct is a simple wrapper around generate(). It sometimes comes -/// in handy to use this class instead of the low level generate() function. -template -struct MoveList { - - explicit MoveList(const Position& pos) : last(generate(pos, moveList)) {} - const ExtMove* begin() const { return moveList; } - const ExtMove* end() const { return last; } - size_t size() const { return last - moveList; } - bool contains(Move move) const { - return std::find(begin(), end(), move) != end(); - } - -private: - ExtMove moveList[MAX_MOVES], *last; -}; - -} // namespace Stockfish - -#endif // #ifndef MOVEGEN_H_INCLUDED + class Position; + + enum GenType { + CAPTURES, + QUIETS, + QUIET_CHECKS, + EVASIONS, + NON_EVASIONS, + LEGAL + }; + + struct ExtMove { + Move move; + int value; + + operator Move() const { return move; } + void operator=(Move m) { move = m; } + + // Inhibit unwanted implicit conversions to Move + // with an ambiguity that yields to a compile error. + operator float() const = delete; + }; + + inline bool operator<(const ExtMove& f, const ExtMove& s) { return f.value < s.value; } + + template ExtMove* generate(const Position& pos, ExtMove* moveList); + + /// The MoveList struct is a simple wrapper around generate(). It sometimes comes + /// in handy to use this class instead of the low level generate() function. + template struct MoveList { + + explicit MoveList(const Position& pos) : last(generate(pos, moveList)) {} + const ExtMove* begin() const { return moveList; } + const ExtMove* end() const { return last; } + size_t size() const { return last - moveList; } + bool contains(Move move) const { return std::find(begin(), end(), move) != end(); } + + private: + ExtMove moveList[MAX_MOVES], *last; + }; + +} // namespace Stockfish + +#endif // #ifndef MOVEGEN_H_INCLUDED diff --git a/src/movepick.cpp b/src/movepick.cpp index d4f8ab092a8..cdb92bf1290 100644 --- a/src/movepick.cpp +++ b/src/movepick.cpp @@ -28,297 +28,289 @@ namespace Stockfish { -namespace { - - enum Stages { - MAIN_TT, CAPTURE_INIT, GOOD_CAPTURE, REFUTATION, QUIET_INIT, QUIET, BAD_CAPTURE, - EVASION_TT, EVASION_INIT, EVASION, - PROBCUT_TT, PROBCUT_INIT, PROBCUT, - QSEARCH_TT, QCAPTURE_INIT, QCAPTURE, QCHECK_INIT, QCHECK - }; - - // partial_insertion_sort() sorts moves in descending order up to and including - // a given limit. The order of moves smaller than the limit is left unspecified. - void partial_insertion_sort(ExtMove* begin, ExtMove* end, int limit) { - - for (ExtMove *sortedEnd = begin, *p = begin + 1; p < end; ++p) - if (p->value >= limit) - { - ExtMove tmp = *p, *q; - *p = *++sortedEnd; - for (q = sortedEnd; q != begin && *(q - 1) < tmp; --q) - *q = *(q - 1); - *q = tmp; + namespace { + + enum Stages { + MAIN_TT, + CAPTURE_INIT, + GOOD_CAPTURE, + REFUTATION, + QUIET_INIT, + QUIET, + BAD_CAPTURE, + EVASION_TT, + EVASION_INIT, + EVASION, + PROBCUT_TT, + PROBCUT_INIT, + PROBCUT, + QSEARCH_TT, + QCAPTURE_INIT, + QCAPTURE, + QCHECK_INIT, + QCHECK + }; + + // partial_insertion_sort() sorts moves in descending order up to and including + // a given limit. The order of moves smaller than the limit is left unspecified. + void partial_insertion_sort(ExtMove* begin, ExtMove* end, int limit) { + + for (ExtMove *sortedEnd = begin, *p = begin + 1; p < end; ++p) + if (p->value >= limit) { + ExtMove tmp = *p, *q; + *p = *++sortedEnd; + for (q = sortedEnd; q != begin && *(q - 1) < tmp; --q) *q = *(q - 1); + *q = tmp; + } } - } - -} // namespace - - -/// Constructors of the MovePicker class. As arguments we pass information -/// to help it to return the (presumably) good moves first, to decide which -/// moves to return (in the quiescence search, for instance, we only want to -/// search captures, promotions, and some checks) and how important good move -/// ordering is at the current node. - -/// MovePicker constructor for the main search -MovePicker::MovePicker(const Position& p, Move ttm, Depth d, const ButterflyHistory* mh, - const CapturePieceToHistory* cph, - const PieceToHistory** ch, - Move cm, - const Move* killers) - : pos(p), mainHistory(mh), captureHistory(cph), continuationHistory(ch), - ttMove(ttm), refutations{{killers[0], 0}, {killers[1], 0}, {cm, 0}}, depth(d) -{ - assert(d > 0); - - stage = (pos.checkers() ? EVASION_TT : MAIN_TT) + - !(ttm && pos.pseudo_legal(ttm)); -} - -/// MovePicker constructor for quiescence search -MovePicker::MovePicker(const Position& p, Move ttm, Depth d, const ButterflyHistory* mh, - const CapturePieceToHistory* cph, - const PieceToHistory** ch, - Square rs) - : pos(p), mainHistory(mh), captureHistory(cph), continuationHistory(ch), ttMove(ttm), recaptureSquare(rs), depth(d) -{ - assert(d <= 0); - - stage = (pos.checkers() ? EVASION_TT : QSEARCH_TT) + - !( ttm - && pos.pseudo_legal(ttm)); -} - -/// MovePicker constructor for ProbCut: we generate captures with SEE greater -/// than or equal to the given threshold. -MovePicker::MovePicker(const Position& p, Move ttm, Value th, const CapturePieceToHistory* cph) - : pos(p), captureHistory(cph), ttMove(ttm), threshold(th) -{ - assert(!pos.checkers()); - - stage = PROBCUT_TT + !(ttm && pos.capture_stage(ttm) - && pos.pseudo_legal(ttm) - && pos.see_ge(ttm, threshold)); -} - -/// MovePicker::score() assigns a numerical value to each move in a list, used -/// for sorting. Captures are ordered by Most Valuable Victim (MVV), preferring -/// captures with a good history. Quiets moves are ordered using the history tables. -template -void MovePicker::score() { - - static_assert(Type == CAPTURES || Type == QUIETS || Type == EVASIONS, "Wrong type"); - - [[maybe_unused]] Bitboard threatenedByPawn, threatenedByMinor, threatenedByRook, threatenedPieces; - if constexpr (Type == QUIETS) - { - Color us = pos.side_to_move(); - - threatenedByPawn = pos.attacks_by(~us); - threatenedByMinor = pos.attacks_by(~us) | pos.attacks_by(~us) | threatenedByPawn; - threatenedByRook = pos.attacks_by(~us) | threatenedByMinor; - - // Pieces threatened by pieces of lesser material value - threatenedPieces = (pos.pieces(us, QUEEN) & threatenedByRook) - | (pos.pieces(us, ROOK) & threatenedByMinor) - | (pos.pieces(us, KNIGHT, BISHOP) & threatenedByPawn); - } - - for (auto& m : *this) - if constexpr (Type == CAPTURES) - m.value = (7 * int(PieceValue[pos.piece_on(to_sq(m))]) - + (*captureHistory)[pos.moved_piece(m)][to_sq(m)][type_of(pos.piece_on(to_sq(m)))]) / 16; - - else if constexpr (Type == QUIETS) - { - Piece pc = pos.moved_piece(m); - PieceType pt = type_of(pos.moved_piece(m)); - Square from = from_sq(m); - Square to = to_sq(m); - - // histories - m.value = 2 * (*mainHistory)[pos.side_to_move()][from_to(m)]; - m.value += 2 * (*continuationHistory[0])[pc][to]; - m.value += (*continuationHistory[1])[pc][to]; - m.value += (*continuationHistory[3])[pc][to]; - m.value += (*continuationHistory[5])[pc][to]; - - // bonus for checks - m.value += bool(pos.check_squares(pt) & to) * 16384; - - // bonus for escaping from capture - m.value += threatenedPieces & from ? - (pt == QUEEN && !(to & threatenedByRook) ? 50000 - : pt == ROOK && !(to & threatenedByMinor) ? 25000 - : !(to & threatenedByPawn) ? 15000 - : 0 ) - : 0 ; - - // malus for putting piece en prise - m.value -= !(threatenedPieces & from) ? - (pt == QUEEN ? bool(to & threatenedByRook) * 50000 - + bool(to & threatenedByMinor) * 10000 - + bool(to & threatenedByPawn) * 20000 - : pt == ROOK ? bool(to & threatenedByMinor) * 25000 - + bool(to & threatenedByPawn) * 10000 - : pt != PAWN ? bool(to & threatenedByPawn) * 15000 - : 0 ) - : 0 ; - } - - else // Type == EVASIONS - { - if (pos.capture_stage(m)) - m.value = PieceValue[pos.piece_on(to_sq(m))] - - Value(type_of(pos.moved_piece(m))) - + (1 << 28); - else - m.value = (*mainHistory)[pos.side_to_move()][from_to(m)] - + (*continuationHistory[0])[pos.moved_piece(m)][to_sq(m)]; - } -} - -/// MovePicker::select() returns the next move satisfying a predicate function. -/// It never returns the TT move. -template -Move MovePicker::select(Pred filter) { - - while (cur < endMoves) - { - if constexpr (T == Best) - std::swap(*cur, *std::max_element(cur, endMoves)); - - if (*cur != ttMove && filter()) - return *cur++; - - cur++; - } - return MOVE_NONE; -} - -/// MovePicker::next_move() is the most important method of the MovePicker class. It -/// returns a new pseudo-legal move every time it is called until there are no more -/// moves left, picking the move with the highest score from a list of generated moves. -Move MovePicker::next_move(bool skipQuiets) { + + } // namespace + + + /// Constructors of the MovePicker class. As arguments we pass information + /// to help it to return the (presumably) good moves first, to decide which + /// moves to return (in the quiescence search, for instance, we only want to + /// search captures, promotions, and some checks) and how important good move + /// ordering is at the current node. + + /// MovePicker constructor for the main search + MovePicker::MovePicker(const Position& p, Move ttm, Depth d, const ButterflyHistory* mh, + const CapturePieceToHistory* cph, const PieceToHistory** ch, Move cm, + const Move* killers) + : pos(p), mainHistory(mh), captureHistory(cph), continuationHistory(ch), ttMove(ttm), + refutations{{killers[0], 0}, {killers[1], 0}, {cm, 0}}, depth(d) { + assert(d > 0); + + stage = (pos.checkers() ? EVASION_TT : MAIN_TT) + !(ttm && pos.pseudo_legal(ttm)); + } + + /// MovePicker constructor for quiescence search + MovePicker::MovePicker(const Position& p, Move ttm, Depth d, const ButterflyHistory* mh, + const CapturePieceToHistory* cph, const PieceToHistory** ch, Square rs) + : pos(p), mainHistory(mh), captureHistory(cph), continuationHistory(ch), ttMove(ttm), + recaptureSquare(rs), depth(d) { + assert(d <= 0); + + stage = (pos.checkers() ? EVASION_TT : QSEARCH_TT) + !(ttm && pos.pseudo_legal(ttm)); + } + + /// MovePicker constructor for ProbCut: we generate captures with SEE greater + /// than or equal to the given threshold. + MovePicker::MovePicker(const Position& p, Move ttm, Value th, const CapturePieceToHistory* cph) + : pos(p), captureHistory(cph), ttMove(ttm), threshold(th) { + assert(!pos.checkers()); + + stage = PROBCUT_TT + !(ttm && pos.capture_stage(ttm) && pos.pseudo_legal(ttm) && + pos.see_ge(ttm, threshold)); + } + + /// MovePicker::score() assigns a numerical value to each move in a list, used + /// for sorting. Captures are ordered by Most Valuable Victim (MVV), preferring + /// captures with a good history. Quiets moves are ordered using the history tables. + template void MovePicker::score() { + + static_assert(Type == CAPTURES || Type == QUIETS || Type == EVASIONS, "Wrong type"); + + [[maybe_unused]] Bitboard threatenedByPawn, threatenedByMinor, threatenedByRook, + threatenedPieces; + if constexpr (Type == QUIETS) { + Color us = pos.side_to_move(); + + threatenedByPawn = pos.attacks_by(~us); + threatenedByMinor = + pos.attacks_by(~us) | pos.attacks_by(~us) | threatenedByPawn; + threatenedByRook = pos.attacks_by(~us) | threatenedByMinor; + + // Pieces threatened by pieces of lesser material value + threatenedPieces = (pos.pieces(us, QUEEN) & threatenedByRook) | + (pos.pieces(us, ROOK) & threatenedByMinor) | + (pos.pieces(us, KNIGHT, BISHOP) & threatenedByPawn); + } + + for (auto& m : *this) + if constexpr (Type == CAPTURES) + m.value = (7 * int(PieceValue[pos.piece_on(to_sq(m))]) + + (*captureHistory)[pos.moved_piece(m)][to_sq(m)] + [type_of(pos.piece_on(to_sq(m)))]) / + 16; + + else if constexpr (Type == QUIETS) { + Piece pc = pos.moved_piece(m); + PieceType pt = type_of(pos.moved_piece(m)); + Square from = from_sq(m); + Square to = to_sq(m); + + // histories + m.value = 2 * (*mainHistory)[pos.side_to_move()][from_to(m)]; + m.value += 2 * (*continuationHistory[0])[pc][to]; + m.value += (*continuationHistory[1])[pc][to]; + m.value += (*continuationHistory[3])[pc][to]; + m.value += (*continuationHistory[5])[pc][to]; + + // bonus for checks + m.value += bool(pos.check_squares(pt) & to) * 16384; + + // bonus for escaping from capture + m.value += threatenedPieces & from ? + (pt == QUEEN && !(to & threatenedByRook) ? 50000 : + pt == ROOK && !(to & threatenedByMinor) ? 25000 : + !(to & threatenedByPawn) ? 15000 : + 0) : + 0; + + // malus for putting piece en prise + m.value -= !(threatenedPieces & from) ? + (pt == QUEEN ? bool(to & threatenedByRook) * 50000 + + bool(to & threatenedByMinor) * 10000 + + bool(to & threatenedByPawn) * 20000 : + pt == ROOK ? bool(to & threatenedByMinor) * 25000 + + bool(to & threatenedByPawn) * 10000 : + pt != PAWN ? bool(to & threatenedByPawn) * 15000 : + 0) : + 0; + } + + else // Type == EVASIONS + { + if (pos.capture_stage(m)) + m.value = PieceValue[pos.piece_on(to_sq(m))] - + Value(type_of(pos.moved_piece(m))) + (1 << 28); + else + m.value = (*mainHistory)[pos.side_to_move()][from_to(m)] + + (*continuationHistory[0])[pos.moved_piece(m)][to_sq(m)]; + } + } + + /// MovePicker::select() returns the next move satisfying a predicate function. + /// It never returns the TT move. + template Move MovePicker::select(Pred filter) { + + while (cur < endMoves) { + if constexpr (T == Best) std::swap(*cur, *std::max_element(cur, endMoves)); + + if (*cur != ttMove && filter()) return *cur++; + + cur++; + } + return MOVE_NONE; + } + + /// MovePicker::next_move() is the most important method of the MovePicker class. It + /// returns a new pseudo-legal move every time it is called until there are no more + /// moves left, picking the move with the highest score from a list of generated moves. + Move MovePicker::next_move(bool skipQuiets) { top: - switch (stage) { - - case MAIN_TT: - case EVASION_TT: - case QSEARCH_TT: - case PROBCUT_TT: - ++stage; - return ttMove; - - case CAPTURE_INIT: - case PROBCUT_INIT: - case QCAPTURE_INIT: - cur = endBadCaptures = moves; - endMoves = generate(pos, cur); - - score(); - partial_insertion_sort(cur, endMoves, std::numeric_limits::min()); - ++stage; - goto top; - - case GOOD_CAPTURE: - if (select([&](){ - return pos.see_ge(*cur, Value(-cur->value)) ? - // Move losing capture to endBadCaptures to be tried later - true : (*endBadCaptures++ = *cur, false); })) - return *(cur - 1); - - // Prepare the pointers to loop over the refutations array - cur = std::begin(refutations); - endMoves = std::end(refutations); - - // If the countermove is the same as a killer, skip it - if ( refutations[0].move == refutations[2].move - || refutations[1].move == refutations[2].move) - --endMoves; - - ++stage; - [[fallthrough]]; - - case REFUTATION: - if (select([&](){ return *cur != MOVE_NONE - && !pos.capture_stage(*cur) - && pos.pseudo_legal(*cur); })) - return *(cur - 1); - ++stage; - [[fallthrough]]; - - case QUIET_INIT: - if (!skipQuiets) - { - cur = endBadCaptures; - endMoves = generate(pos, cur); - - score(); - partial_insertion_sort(cur, endMoves, -3000 * depth); - } - - ++stage; - [[fallthrough]]; - - case QUIET: - if ( !skipQuiets - && select([&](){return *cur != refutations[0].move - && *cur != refutations[1].move - && *cur != refutations[2].move;})) - return *(cur - 1); - - // Prepare the pointers to loop over the bad captures - cur = moves; - endMoves = endBadCaptures; - - ++stage; - [[fallthrough]]; - - case BAD_CAPTURE: - return select([](){ return true; }); - - case EVASION_INIT: - cur = moves; - endMoves = generate(pos, cur); - - score(); - ++stage; - [[fallthrough]]; - - case EVASION: - return select([](){ return true; }); - - case PROBCUT: - return select([&](){ return pos.see_ge(*cur, threshold); }); - - case QCAPTURE: - if (select([&](){ return depth > DEPTH_QS_RECAPTURES - || to_sq(*cur) == recaptureSquare; })) - return *(cur - 1); - - // If we did not find any move and we do not try checks, we have finished - if (depth != DEPTH_QS_CHECKS) - return MOVE_NONE; - - ++stage; - [[fallthrough]]; - - case QCHECK_INIT: - cur = moves; - endMoves = generate(pos, cur); - - ++stage; - [[fallthrough]]; - - case QCHECK: - return select([](){ return true; }); - } - - assert(false); - return MOVE_NONE; // Silence warning -} - -} // namespace Stockfish + switch (stage) { + + case MAIN_TT : + case EVASION_TT : + case QSEARCH_TT : + case PROBCUT_TT : ++stage; return ttMove; + + case CAPTURE_INIT : + case PROBCUT_INIT : + case QCAPTURE_INIT : + cur = endBadCaptures = moves; + endMoves = generate(pos, cur); + + score(); + partial_insertion_sort(cur, endMoves, std::numeric_limits::min()); + ++stage; + goto top; + + case GOOD_CAPTURE : + if (select([&]() { + return pos.see_ge(*cur, Value(-cur->value)) ? + // Move losing capture to endBadCaptures to be tried later + true : + (*endBadCaptures++ = *cur, false); + })) + return *(cur - 1); + + // Prepare the pointers to loop over the refutations array + cur = std::begin(refutations); + endMoves = std::end(refutations); + + // If the countermove is the same as a killer, skip it + if (refutations[0].move == refutations[2].move || + refutations[1].move == refutations[2].move) + --endMoves; + + ++stage; + [[fallthrough]]; + + case REFUTATION : + if (select([&]() { + return *cur != MOVE_NONE && !pos.capture_stage(*cur) && pos.pseudo_legal(*cur); + })) + return *(cur - 1); + ++stage; + [[fallthrough]]; + + case QUIET_INIT : + if (!skipQuiets) { + cur = endBadCaptures; + endMoves = generate(pos, cur); + + score(); + partial_insertion_sort(cur, endMoves, -3000 * depth); + } + + ++stage; + [[fallthrough]]; + + case QUIET : + if (!skipQuiets && select([&]() { + return *cur != refutations[0].move && *cur != refutations[1].move && + *cur != refutations[2].move; + })) + return *(cur - 1); + + // Prepare the pointers to loop over the bad captures + cur = moves; + endMoves = endBadCaptures; + + ++stage; + [[fallthrough]]; + + case BAD_CAPTURE : return select([]() { return true; }); + + case EVASION_INIT : + cur = moves; + endMoves = generate(pos, cur); + + score(); + ++stage; + [[fallthrough]]; + + case EVASION : return select([]() { return true; }); + + case PROBCUT : return select([&]() { return pos.see_ge(*cur, threshold); }); + + case QCAPTURE : + if (select( + [&]() { return depth > DEPTH_QS_RECAPTURES || to_sq(*cur) == recaptureSquare; })) + return *(cur - 1); + + // If we did not find any move and we do not try checks, we have finished + if (depth != DEPTH_QS_CHECKS) return MOVE_NONE; + + ++stage; + [[fallthrough]]; + + case QCHECK_INIT : + cur = moves; + endMoves = generate(pos, cur); + + ++stage; + [[fallthrough]]; + + case QCHECK : return select([]() { return true; }); + } + + assert(false); + return MOVE_NONE; // Silence warning + } + +} // namespace Stockfish diff --git a/src/movepick.h b/src/movepick.h index 5243f89cf2c..c3c12b662f5 100644 --- a/src/movepick.h +++ b/src/movepick.h @@ -30,129 +30,130 @@ #include "types.h" namespace Stockfish { -class Position; - -/// StatsEntry stores the stat table value. It is usually a number but could -/// be a move or even a nested history. We use a class instead of naked value -/// to directly call history update operator<<() on the entry so to use stats -/// tables at caller sites as simple multi-dim arrays. -template -class StatsEntry { - - T entry; - -public: - void operator=(const T& v) { entry = v; } - T* operator&() { return &entry; } - T* operator->() { return &entry; } - operator const T&() const { return entry; } - - void operator<<(int bonus) { - assert(abs(bonus) <= D); // Ensure range is [-D, D] - static_assert(D <= std::numeric_limits::max(), "D overflows T"); - - entry += bonus - entry * abs(bonus) / D; - - assert(abs(entry) <= D); - } -}; - -/// Stats is a generic N-dimensional array used to store various statistics. -/// The first template parameter T is the base type of the array, the second -/// template parameter D limits the range of updates in [-D, D] when we update -/// values with the << operator, while the last parameters (Size and Sizes) -/// encode the dimensions of the array. -template -struct Stats : public std::array, Size> -{ - using stats = Stats; - - void fill(const T& v) { - - // For standard-layout 'this' points to first struct member - assert(std::is_standard_layout::value); - - using entry = StatsEntry; - entry* p = reinterpret_cast(this); - std::fill(p, p + sizeof(*this) / sizeof(entry), v); - } -}; - -template -struct Stats : public std::array, Size> {}; - -/// In stats table, D=0 means that the template parameter is not used -enum StatsParams { NOT_USED = 0 }; -enum StatsType { NoCaptures, Captures }; - -/// ButterflyHistory records how often quiet moves have been successful or -/// unsuccessful during the current search, and is used for reduction and move -/// ordering decisions. It uses 2 tables (one for each color) indexed by -/// the move's from and to squares, see www.chessprogramming.org/Butterfly_Boards -/// (~11 elo) -using ButterflyHistory = Stats; - -/// CounterMoveHistory stores counter moves indexed by [piece][to] of the previous -/// move, see www.chessprogramming.org/Countermove_Heuristic -using CounterMoveHistory = Stats; - -/// CapturePieceToHistory is addressed by a move's [piece][to][captured piece type] -using CapturePieceToHistory = Stats; - -/// PieceToHistory is like ButterflyHistory but is addressed by a move's [piece][to] -using PieceToHistory = Stats; - -/// ContinuationHistory is the combined history of a given pair of moves, usually -/// the current one given a previous one. The nested history table is based on -/// PieceToHistory instead of ButterflyBoards. -/// (~63 elo) -using ContinuationHistory = Stats; - - -/// MovePicker class is used to pick one pseudo-legal move at a time from the -/// current position. The most important method is next_move(), which returns a -/// new pseudo-legal move each time it is called, until there are no moves left, -/// when MOVE_NONE is returned. In order to improve the efficiency of the -/// alpha-beta algorithm, MovePicker attempts to return the moves which are most -/// likely to get a cut-off first. -class MovePicker { - - enum PickType { Next, Best }; - -public: - MovePicker(const MovePicker&) = delete; - MovePicker& operator=(const MovePicker&) = delete; - MovePicker(const Position&, Move, Depth, const ButterflyHistory*, - const CapturePieceToHistory*, - const PieceToHistory**, - Move, - const Move*); - MovePicker(const Position&, Move, Depth, const ButterflyHistory*, - const CapturePieceToHistory*, - const PieceToHistory**, - Square); - MovePicker(const Position&, Move, Value, const CapturePieceToHistory*); - Move next_move(bool skipQuiets = false); - -private: - template Move select(Pred); - template void score(); - ExtMove* begin() { return cur; } - ExtMove* end() { return endMoves; } - - const Position& pos; - const ButterflyHistory* mainHistory; - const CapturePieceToHistory* captureHistory; - const PieceToHistory** continuationHistory; - Move ttMove; - ExtMove refutations[3], *cur, *endMoves, *endBadCaptures; - int stage; - Square recaptureSquare; - Value threshold; - Depth depth; - ExtMove moves[MAX_MOVES]; -}; - -} // namespace Stockfish - -#endif // #ifndef MOVEPICK_H_INCLUDED + class Position; + + /// StatsEntry stores the stat table value. It is usually a number but could + /// be a move or even a nested history. We use a class instead of naked value + /// to directly call history update operator<<() on the entry so to use stats + /// tables at caller sites as simple multi-dim arrays. + template class StatsEntry { + + T entry; + + public: + void operator=(const T& v) { entry = v; } + T* operator&() { return &entry; } + T* operator->() { return &entry; } + operator const T&() const { return entry; } + + void operator<<(int bonus) { + assert(abs(bonus) <= D); // Ensure range is [-D, D] + static_assert(D <= std::numeric_limits::max(), "D overflows T"); + + entry += bonus - entry * abs(bonus) / D; + + assert(abs(entry) <= D); + } + }; + + /// Stats is a generic N-dimensional array used to store various statistics. + /// The first template parameter T is the base type of the array, the second + /// template parameter D limits the range of updates in [-D, D] when we update + /// values with the << operator, while the last parameters (Size and Sizes) + /// encode the dimensions of the array. + template struct Stats + : public std::array, Size> { + using stats = Stats; + + void fill(const T& v) { + + // For standard-layout 'this' points to first struct member + assert(std::is_standard_layout::value); + + using entry = StatsEntry; + entry* p = reinterpret_cast(this); + std::fill(p, p + sizeof(*this) / sizeof(entry), v); + } + }; + + template struct Stats + : public std::array, Size> {}; + + /// In stats table, D=0 means that the template parameter is not used + enum StatsParams { + NOT_USED = 0 + }; + enum StatsType { + NoCaptures, + Captures + }; + + /// ButterflyHistory records how often quiet moves have been successful or + /// unsuccessful during the current search, and is used for reduction and move + /// ordering decisions. It uses 2 tables (one for each color) indexed by + /// the move's from and to squares, see www.chessprogramming.org/Butterfly_Boards + /// (~11 elo) + using ButterflyHistory = Stats; + + /// CounterMoveHistory stores counter moves indexed by [piece][to] of the previous + /// move, see www.chessprogramming.org/Countermove_Heuristic + using CounterMoveHistory = Stats; + + /// CapturePieceToHistory is addressed by a move's [piece][to][captured piece type] + using CapturePieceToHistory = Stats; + + /// PieceToHistory is like ButterflyHistory but is addressed by a move's [piece][to] + using PieceToHistory = Stats; + + /// ContinuationHistory is the combined history of a given pair of moves, usually + /// the current one given a previous one. The nested history table is based on + /// PieceToHistory instead of ButterflyBoards. + /// (~63 elo) + using ContinuationHistory = Stats; + + + /// MovePicker class is used to pick one pseudo-legal move at a time from the + /// current position. The most important method is next_move(), which returns a + /// new pseudo-legal move each time it is called, until there are no moves left, + /// when MOVE_NONE is returned. In order to improve the efficiency of the + /// alpha-beta algorithm, MovePicker attempts to return the moves which are most + /// likely to get a cut-off first. + class MovePicker { + + enum PickType { + Next, + Best + }; + + public: + MovePicker(const MovePicker&) = delete; + MovePicker& operator=(const MovePicker&) = delete; + MovePicker(const Position&, Move, Depth, const ButterflyHistory*, + const CapturePieceToHistory*, const PieceToHistory**, Move, const Move*); + MovePicker(const Position&, Move, Depth, const ButterflyHistory*, + const CapturePieceToHistory*, const PieceToHistory**, Square); + MovePicker(const Position&, Move, Value, const CapturePieceToHistory*); + Move next_move(bool skipQuiets = false); + + private: + template Move select(Pred); + template void score(); + ExtMove* begin() { return cur; } + ExtMove* end() { return endMoves; } + + const Position& pos; + const ButterflyHistory* mainHistory; + const CapturePieceToHistory* captureHistory; + const PieceToHistory** continuationHistory; + Move ttMove; + ExtMove refutations[3], *cur, *endMoves, *endBadCaptures; + int stage; + Square recaptureSquare; + Value threshold; + Depth depth; + ExtMove moves[MAX_MOVES]; + }; + +} // namespace Stockfish + +#endif // #ifndef MOVEPICK_H_INCLUDED diff --git a/src/nnue/evaluate_nnue.cpp b/src/nnue/evaluate_nnue.cpp index 456f2edfdf3..7e71ce31bb9 100644 --- a/src/nnue/evaluate_nnue.cpp +++ b/src/nnue/evaluate_nnue.cpp @@ -39,371 +39,363 @@ namespace Stockfish::Eval::NNUE { - // Input feature converter - LargePagePtr featureTransformer; - - // Evaluation function - AlignedPtr network[LayerStacks]; - - // Evaluation function file name - std::string fileName; - std::string netDescription; - - namespace Detail { - - // Initialize the evaluation function parameters - template - void initialize(AlignedPtr& pointer) { - - pointer.reset(reinterpret_cast(std_aligned_alloc(alignof(T), sizeof(T)))); - std::memset(pointer.get(), 0, sizeof(T)); - } - - template - void initialize(LargePagePtr& pointer) { - - static_assert(alignof(T) <= 4096, "aligned_large_pages_alloc() may fail for such a big alignment requirement of T"); - pointer.reset(reinterpret_cast(aligned_large_pages_alloc(sizeof(T)))); - std::memset(pointer.get(), 0, sizeof(T)); - } - - // Read evaluation function parameters - template - bool read_parameters(std::istream& stream, T& reference) { - - std::uint32_t header; - header = read_little_endian(stream); - if (!stream || header != T::get_hash_value()) return false; - return reference.read_parameters(stream); - } - - // Write evaluation function parameters - template - bool write_parameters(std::ostream& stream, const T& reference) { - - write_little_endian(stream, T::get_hash_value()); - return reference.write_parameters(stream); - } - - } // namespace Detail - - - // Initialize the evaluation function parameters - static void initialize() { - - Detail::initialize(featureTransformer); - for (std::size_t i = 0; i < LayerStacks; ++i) - Detail::initialize(network[i]); - } - - // Read network header - static bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc) - { - std::uint32_t version, size; - - version = read_little_endian(stream); - *hashValue = read_little_endian(stream); - size = read_little_endian(stream); - if (!stream || version != Version) return false; - desc->resize(size); - stream.read(&(*desc)[0], size); - return !stream.fail(); - } - - // Write network header - static bool write_header(std::ostream& stream, std::uint32_t hashValue, const std::string& desc) - { - write_little_endian(stream, Version); - write_little_endian(stream, hashValue); - write_little_endian(stream, (std::uint32_t)desc.size()); - stream.write(&desc[0], desc.size()); - return !stream.fail(); - } - - // Read network parameters - static bool read_parameters(std::istream& stream) { - - std::uint32_t hashValue; - if (!read_header(stream, &hashValue, &netDescription)) return false; - if (hashValue != HashValue) return false; - if (!Detail::read_parameters(stream, *featureTransformer)) return false; - for (std::size_t i = 0; i < LayerStacks; ++i) - if (!Detail::read_parameters(stream, *(network[i]))) return false; - return stream && stream.peek() == std::ios::traits_type::eof(); - } - - // Write network parameters - static bool write_parameters(std::ostream& stream) { - - if (!write_header(stream, HashValue, netDescription)) return false; - if (!Detail::write_parameters(stream, *featureTransformer)) return false; - for (std::size_t i = 0; i < LayerStacks; ++i) - if (!Detail::write_parameters(stream, *(network[i]))) return false; - return (bool)stream; - } - - void hint_common_parent_position(const Position& pos) { - featureTransformer->hint_common_access(pos); - } - - // Evaluation function. Perform differential calculation. - Value evaluate(const Position& pos, bool adjusted, int* complexity) { - - // We manually align the arrays on the stack because with gcc < 9.3 - // overaligning stack variables with alignas() doesn't work correctly. - - constexpr uint64_t alignment = CacheLineSize; - constexpr int delta = 24; + // Input feature converter + LargePagePtr featureTransformer; -#if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) - TransformedFeatureType transformedFeaturesUnaligned[ - FeatureTransformer::BufferSize + alignment / sizeof(TransformedFeatureType)]; + // Evaluation function + AlignedPtr network[LayerStacks]; - auto* transformedFeatures = align_ptr_up(&transformedFeaturesUnaligned[0]); -#else - alignas(alignment) - TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize]; -#endif + // Evaluation function file name + std::string fileName; + std::string netDescription; - ASSERT_ALIGNED(transformedFeatures, alignment); + namespace Detail { - const int bucket = (pos.count() - 1) / 4; - const auto psqt = featureTransformer->transform(pos, transformedFeatures, bucket); - const auto positional = network[bucket]->propagate(transformedFeatures); + // Initialize the evaluation function parameters + template void initialize(AlignedPtr& pointer) { - if (complexity) - *complexity = abs(psqt - positional) / OutputScale; + pointer.reset(reinterpret_cast(std_aligned_alloc(alignof(T), sizeof(T)))); + std::memset(pointer.get(), 0, sizeof(T)); + } - // Give more value to positional evaluation when adjusted flag is set - if (adjusted) - return static_cast(((1024 - delta) * psqt + (1024 + delta) * positional) / (1024 * OutputScale)); - else - return static_cast((psqt + positional) / OutputScale); - } + template void initialize(LargePagePtr& pointer) { - struct NnueEvalTrace { - static_assert(LayerStacks == PSQTBuckets); + static_assert( + alignof(T) <= 4096, + "aligned_large_pages_alloc() may fail for such a big alignment requirement of T"); + pointer.reset(reinterpret_cast(aligned_large_pages_alloc(sizeof(T)))); + std::memset(pointer.get(), 0, sizeof(T)); + } - Value psqt[LayerStacks]; - Value positional[LayerStacks]; - std::size_t correctBucket; - }; + // Read evaluation function parameters + template bool read_parameters(std::istream& stream, T& reference) { - static NnueEvalTrace trace_evaluate(const Position& pos) { + std::uint32_t header; + header = read_little_endian(stream); + if (!stream || header != T::get_hash_value()) return false; + return reference.read_parameters(stream); + } - // We manually align the arrays on the stack because with gcc < 9.3 - // overaligning stack variables with alignas() doesn't work correctly. - constexpr uint64_t alignment = CacheLineSize; + // Write evaluation function parameters + template bool write_parameters(std::ostream& stream, const T& reference) { -#if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) - TransformedFeatureType transformedFeaturesUnaligned[ - FeatureTransformer::BufferSize + alignment / sizeof(TransformedFeatureType)]; + write_little_endian(stream, T::get_hash_value()); + return reference.write_parameters(stream); + } - auto* transformedFeatures = align_ptr_up(&transformedFeaturesUnaligned[0]); -#else - alignas(alignment) - TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize]; -#endif + } // namespace Detail - ASSERT_ALIGNED(transformedFeatures, alignment); - NnueEvalTrace t{}; - t.correctBucket = (pos.count() - 1) / 4; - for (IndexType bucket = 0; bucket < LayerStacks; ++bucket) { - const auto materialist = featureTransformer->transform(pos, transformedFeatures, bucket); - const auto positional = network[bucket]->propagate(transformedFeatures); + // Initialize the evaluation function parameters + static void initialize() { - t.psqt[bucket] = static_cast( materialist / OutputScale ); - t.positional[bucket] = static_cast( positional / OutputScale ); + Detail::initialize(featureTransformer); + for (std::size_t i = 0; i < LayerStacks; ++i) Detail::initialize(network[i]); } - return t; - } + // Read network header + static bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc) { + std::uint32_t version, size; + + version = read_little_endian(stream); + *hashValue = read_little_endian(stream); + size = read_little_endian(stream); + if (!stream || version != Version) return false; + desc->resize(size); + stream.read(&(*desc)[0], size); + return !stream.fail(); + } - constexpr std::string_view PieceToChar(" PNBRQK pnbrqk"); + // Write network header + static bool write_header(std::ostream& stream, std::uint32_t hashValue, + const std::string& desc) { + write_little_endian(stream, Version); + write_little_endian(stream, hashValue); + write_little_endian(stream, (std::uint32_t) desc.size()); + stream.write(&desc[0], desc.size()); + return !stream.fail(); + } + // Read network parameters + static bool read_parameters(std::istream& stream) { - // format_cp_compact() converts a Value into (centi)pawns and writes it in a buffer. - // The buffer must have capacity for at least 5 chars. - static void format_cp_compact(Value v, char* buffer) { + std::uint32_t hashValue; + if (!read_header(stream, &hashValue, &netDescription)) return false; + if (hashValue != HashValue) return false; + if (!Detail::read_parameters(stream, *featureTransformer)) return false; + for (std::size_t i = 0; i < LayerStacks; ++i) + if (!Detail::read_parameters(stream, *(network[i]))) return false; + return stream && stream.peek() == std::ios::traits_type::eof(); + } - buffer[0] = (v < 0 ? '-' : v > 0 ? '+' : ' '); + // Write network parameters + static bool write_parameters(std::ostream& stream) { - int cp = std::abs(UCI::to_cp(v)); - if (cp >= 10000) - { - buffer[1] = '0' + cp / 10000; cp %= 10000; - buffer[2] = '0' + cp / 1000; cp %= 1000; - buffer[3] = '0' + cp / 100; - buffer[4] = ' '; - } - else if (cp >= 1000) - { - buffer[1] = '0' + cp / 1000; cp %= 1000; - buffer[2] = '0' + cp / 100; cp %= 100; - buffer[3] = '.'; - buffer[4] = '0' + cp / 10; + if (!write_header(stream, HashValue, netDescription)) return false; + if (!Detail::write_parameters(stream, *featureTransformer)) return false; + for (std::size_t i = 0; i < LayerStacks; ++i) + if (!Detail::write_parameters(stream, *(network[i]))) return false; + return (bool) stream; } - else - { - buffer[1] = '0' + cp / 100; cp %= 100; - buffer[2] = '.'; - buffer[3] = '0' + cp / 10; cp %= 10; - buffer[4] = '0' + cp / 1; + + void hint_common_parent_position(const Position& pos) { + featureTransformer->hint_common_access(pos); } - } + // Evaluation function. Perform differential calculation. + Value evaluate(const Position& pos, bool adjusted, int* complexity) { - // format_cp_aligned_dot() converts a Value into pawns, always keeping two decimals - static void format_cp_aligned_dot(Value v, std::stringstream &stream) { + // We manually align the arrays on the stack because with gcc < 9.3 + // overaligning stack variables with alignas() doesn't work correctly. - const double pawns = std::abs(0.01 * UCI::to_cp(v)); + constexpr uint64_t alignment = CacheLineSize; + constexpr int delta = 24; - stream << (v < 0 ? '-' : v > 0 ? '+' : ' ') - << std::setiosflags(std::ios::fixed) - << std::setw(6) - << std::setprecision(2) - << pawns; - } +#if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) + TransformedFeatureType + transformedFeaturesUnaligned[FeatureTransformer::BufferSize + + alignment / sizeof(TransformedFeatureType)]; + auto* transformedFeatures = align_ptr_up(&transformedFeaturesUnaligned[0]); +#else + alignas(alignment) + TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize]; +#endif - // trace() returns a string with the value of each piece on a board, - // and a table for (PSQT, Layers) values bucket by bucket. - std::string trace(Position& pos) { + ASSERT_ALIGNED(transformedFeatures, alignment); - std::stringstream ss; + const int bucket = (pos.count() - 1) / 4; + const auto psqt = featureTransformer->transform(pos, transformedFeatures, bucket); + const auto positional = network[bucket]->propagate(transformedFeatures); - char board[3*8+1][8*8+2]; - std::memset(board, ' ', sizeof(board)); - for (int row = 0; row < 3*8+1; ++row) - board[row][8*8+1] = '\0'; + if (complexity) *complexity = abs(psqt - positional) / OutputScale; - // A lambda to output one box of the board - auto writeSquare = [&board](File file, Rank rank, Piece pc, Value value) { + // Give more value to positional evaluation when adjusted flag is set + if (adjusted) + return static_cast(((1024 - delta) * psqt + (1024 + delta) * positional) / + (1024 * OutputScale)); + else + return static_cast((psqt + positional) / OutputScale); + } + + struct NnueEvalTrace { + static_assert(LayerStacks == PSQTBuckets); - const int x = ((int)file) * 8; - const int y = (7 - (int)rank) * 3; - for (int i = 1; i < 8; ++i) - board[y][x+i] = board[y+3][x+i] = '-'; - for (int i = 1; i < 3; ++i) - board[y+i][x] = board[y+i][x+8] = '|'; - board[y][x] = board[y][x+8] = board[y+3][x+8] = board[y+3][x] = '+'; - if (pc != NO_PIECE) - board[y+1][x+4] = PieceToChar[pc]; - if (value != VALUE_NONE) - format_cp_compact(value, &board[y+2][x+2]); + Value psqt[LayerStacks]; + Value positional[LayerStacks]; + std::size_t correctBucket; }; - // We estimate the value of each piece by doing a differential evaluation from - // the current base eval, simulating the removal of the piece from its square. - Value base = evaluate(pos); - base = pos.side_to_move() == WHITE ? base : -base; - - for (File f = FILE_A; f <= FILE_H; ++f) - for (Rank r = RANK_1; r <= RANK_8; ++r) - { - Square sq = make_square(f, r); - Piece pc = pos.piece_on(sq); - Value v = VALUE_NONE; - - if (pc != NO_PIECE && type_of(pc) != KING) - { - auto st = pos.state(); - - pos.remove_piece(sq); - st->accumulator.computed[WHITE] = false; - st->accumulator.computed[BLACK] = false; - - Value eval = evaluate(pos); - eval = pos.side_to_move() == WHITE ? eval : -eval; - v = base - eval; - - pos.put_piece(pc, sq); - st->accumulator.computed[WHITE] = false; - st->accumulator.computed[BLACK] = false; + static NnueEvalTrace trace_evaluate(const Position& pos) { + + // We manually align the arrays on the stack because with gcc < 9.3 + // overaligning stack variables with alignas() doesn't work correctly. + constexpr uint64_t alignment = CacheLineSize; + +#if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) + TransformedFeatureType + transformedFeaturesUnaligned[FeatureTransformer::BufferSize + + alignment / sizeof(TransformedFeatureType)]; + + auto* transformedFeatures = align_ptr_up(&transformedFeaturesUnaligned[0]); +#else + alignas(alignment) + TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize]; +#endif + + ASSERT_ALIGNED(transformedFeatures, alignment); + + NnueEvalTrace t{}; + t.correctBucket = (pos.count() - 1) / 4; + for (IndexType bucket = 0; bucket < LayerStacks; ++bucket) { + const auto materialist = + featureTransformer->transform(pos, transformedFeatures, bucket); + const auto positional = network[bucket]->propagate(transformedFeatures); + + t.psqt[bucket] = static_cast(materialist / OutputScale); + t.positional[bucket] = static_cast(positional / OutputScale); } - writeSquare(f, r, pc, v); - } - - ss << " NNUE derived piece values:\n"; - for (int row = 0; row < 3*8+1; ++row) - ss << board[row] << '\n'; - ss << '\n'; - - auto t = trace_evaluate(pos); - - ss << " NNUE network contributions " - << (pos.side_to_move() == WHITE ? "(White to move)" : "(Black to move)") << std::endl - << "+------------+------------+------------+------------+\n" - << "| Bucket | Material | Positional | Total |\n" - << "| | (PSQT) | (Layers) | |\n" - << "+------------+------------+------------+------------+\n"; - - for (std::size_t bucket = 0; bucket < LayerStacks; ++bucket) - { - ss << "| " << bucket << " "; - ss << " | "; format_cp_aligned_dot(t.psqt[bucket], ss); ss << " " - << " | "; format_cp_aligned_dot(t.positional[bucket], ss); ss << " " - << " | "; format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], ss); ss << " " - << " |"; - if (bucket == t.correctBucket) - ss << " <-- this bucket is used"; - ss << '\n'; + return t; + } + + constexpr std::string_view PieceToChar(" PNBRQK pnbrqk"); + + + // format_cp_compact() converts a Value into (centi)pawns and writes it in a buffer. + // The buffer must have capacity for at least 5 chars. + static void format_cp_compact(Value v, char* buffer) { + + buffer[0] = (v < 0 ? '-' : v > 0 ? '+' : ' '); + + int cp = std::abs(UCI::to_cp(v)); + if (cp >= 10000) { + buffer[1] = '0' + cp / 10000; + cp %= 10000; + buffer[2] = '0' + cp / 1000; + cp %= 1000; + buffer[3] = '0' + cp / 100; + buffer[4] = ' '; + } else if (cp >= 1000) { + buffer[1] = '0' + cp / 1000; + cp %= 1000; + buffer[2] = '0' + cp / 100; + cp %= 100; + buffer[3] = '.'; + buffer[4] = '0' + cp / 10; + } else { + buffer[1] = '0' + cp / 100; + cp %= 100; + buffer[2] = '.'; + buffer[3] = '0' + cp / 10; + cp %= 10; + buffer[4] = '0' + cp / 1; + } + } + + + // format_cp_aligned_dot() converts a Value into pawns, always keeping two decimals + static void format_cp_aligned_dot(Value v, std::stringstream& stream) { + + const double pawns = std::abs(0.01 * UCI::to_cp(v)); + + stream << (v < 0 ? '-' : + v > 0 ? '+' : + ' ') + << std::setiosflags(std::ios::fixed) << std::setw(6) << std::setprecision(2) + << pawns; } - ss << "+------------+------------+------------+------------+\n"; - return ss.str(); - } + // trace() returns a string with the value of each piece on a board, + // and a table for (PSQT, Layers) values bucket by bucket. + std::string trace(Position& pos) { + + std::stringstream ss; + + char board[3 * 8 + 1][8 * 8 + 2]; + std::memset(board, ' ', sizeof(board)); + for (int row = 0; row < 3 * 8 + 1; ++row) board[row][8 * 8 + 1] = '\0'; + + // A lambda to output one box of the board + auto writeSquare = [&board](File file, Rank rank, Piece pc, Value value) { + const int x = ((int) file) * 8; + const int y = (7 - (int) rank) * 3; + for (int i = 1; i < 8; ++i) board[y][x + i] = board[y + 3][x + i] = '-'; + for (int i = 1; i < 3; ++i) board[y + i][x] = board[y + i][x + 8] = '|'; + board[y][x] = board[y][x + 8] = board[y + 3][x + 8] = board[y + 3][x] = '+'; + if (pc != NO_PIECE) board[y + 1][x + 4] = PieceToChar[pc]; + if (value != VALUE_NONE) format_cp_compact(value, &board[y + 2][x + 2]); + }; + + // We estimate the value of each piece by doing a differential evaluation from + // the current base eval, simulating the removal of the piece from its square. + Value base = evaluate(pos); + base = pos.side_to_move() == WHITE ? base : -base; + + for (File f = FILE_A; f <= FILE_H; ++f) + for (Rank r = RANK_1; r <= RANK_8; ++r) { + Square sq = make_square(f, r); + Piece pc = pos.piece_on(sq); + Value v = VALUE_NONE; + + if (pc != NO_PIECE && type_of(pc) != KING) { + auto st = pos.state(); + + pos.remove_piece(sq); + st->accumulator.computed[WHITE] = false; + st->accumulator.computed[BLACK] = false; + + Value eval = evaluate(pos); + eval = pos.side_to_move() == WHITE ? eval : -eval; + v = base - eval; + + pos.put_piece(pc, sq); + st->accumulator.computed[WHITE] = false; + st->accumulator.computed[BLACK] = false; + } + + writeSquare(f, r, pc, v); + } + + ss << " NNUE derived piece values:\n"; + for (int row = 0; row < 3 * 8 + 1; ++row) ss << board[row] << '\n'; + ss << '\n'; + + auto t = trace_evaluate(pos); + + ss << " NNUE network contributions " + << (pos.side_to_move() == WHITE ? "(White to move)" : "(Black to move)") << std::endl + << "+------------+------------+------------+------------+\n" + << "| Bucket | Material | Positional | Total |\n" + << "| | (PSQT) | (Layers) | |\n" + << "+------------+------------+------------+------------+\n"; + + for (std::size_t bucket = 0; bucket < LayerStacks; ++bucket) { + ss << "| " << bucket << " "; + ss << " | "; + format_cp_aligned_dot(t.psqt[bucket], ss); + ss << " " + << " | "; + format_cp_aligned_dot(t.positional[bucket], ss); + ss << " " + << " | "; + format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], ss); + ss << " " + << " |"; + if (bucket == t.correctBucket) ss << " <-- this bucket is used"; + ss << '\n'; + } + ss << "+------------+------------+------------+------------+\n"; - // Load eval, from a file stream or a memory stream - bool load_eval(std::string name, std::istream& stream) { + return ss.str(); + } - initialize(); - fileName = name; - return read_parameters(stream); - } - // Save eval, to a file stream or a memory stream - bool save_eval(std::ostream& stream) { + // Load eval, from a file stream or a memory stream + bool load_eval(std::string name, std::istream& stream) { - if (fileName.empty()) - return false; + initialize(); + fileName = name; + return read_parameters(stream); + } - return write_parameters(stream); - } + // Save eval, to a file stream or a memory stream + bool save_eval(std::ostream& stream) { - /// Save eval, to a file given by its name - bool save_eval(const std::optional& filename) { + if (fileName.empty()) return false; - std::string actualFilename; - std::string msg; + return write_parameters(stream); + } + + /// Save eval, to a file given by its name + bool save_eval(const std::optional& filename) { + + std::string actualFilename; + std::string msg; - if (filename.has_value()) - actualFilename = filename.value(); - else - { - if (currentEvalFileName != EvalFileDefaultName) - { - msg = "Failed to export a net. A non-embedded net can only be saved if the filename is specified"; + if (filename.has_value()) + actualFilename = filename.value(); + else { + if (currentEvalFileName != EvalFileDefaultName) { + msg = + "Failed to export a net. A non-embedded net can only be saved if the filename is specified"; - sync_cout << msg << sync_endl; - return false; + sync_cout << msg << sync_endl; + return false; + } + actualFilename = EvalFileDefaultName; } - actualFilename = EvalFileDefaultName; - } - std::ofstream stream(actualFilename, std::ios_base::binary); - bool saved = save_eval(stream); + std::ofstream stream(actualFilename, std::ios_base::binary); + bool saved = save_eval(stream); - msg = saved ? "Network saved successfully to " + actualFilename - : "Failed to export a net"; + msg = saved ? "Network saved successfully to " + actualFilename : "Failed to export a net"; - sync_cout << msg << sync_endl; - return saved; - } + sync_cout << msg << sync_endl; + return saved; + } -} // namespace Stockfish::Eval::NNUE +} // namespace Stockfish::Eval::NNUE diff --git a/src/nnue/evaluate_nnue.h b/src/nnue/evaluate_nnue.h index 8faec6cce43..64bcf75295c 100644 --- a/src/nnue/evaluate_nnue.h +++ b/src/nnue/evaluate_nnue.h @@ -32,48 +32,44 @@ #include "nnue_feature_transformer.h" namespace Stockfish { - class Position; - enum Value : int; + class Position; + enum Value : int; } namespace Stockfish::Eval::NNUE { - // Hash value of evaluation function structure - constexpr std::uint32_t HashValue = + // Hash value of evaluation function structure + constexpr std::uint32_t HashValue = FeatureTransformer::get_hash_value() ^ Network::get_hash_value(); - // Deleter for automating release of memory area - template - struct AlignedDeleter { - void operator()(T* ptr) const { - ptr->~T(); - std_aligned_free(ptr); - } - }; + // Deleter for automating release of memory area + template struct AlignedDeleter { + void operator()(T* ptr) const { + ptr->~T(); + std_aligned_free(ptr); + } + }; - template - struct LargePageDeleter { - void operator()(T* ptr) const { - ptr->~T(); - aligned_large_pages_free(ptr); - } - }; + template struct LargePageDeleter { + void operator()(T* ptr) const { + ptr->~T(); + aligned_large_pages_free(ptr); + } + }; - template - using AlignedPtr = std::unique_ptr>; + template using AlignedPtr = std::unique_ptr>; - template - using LargePagePtr = std::unique_ptr>; + template using LargePagePtr = std::unique_ptr>; - std::string trace(Position& pos); - Value evaluate(const Position& pos, bool adjusted = false, int* complexity = nullptr); - void hint_common_parent_position(const Position& pos); + std::string trace(Position& pos); + Value evaluate(const Position& pos, bool adjusted = false, int* complexity = nullptr); + void hint_common_parent_position(const Position& pos); - bool load_eval(std::string name, std::istream& stream); - bool save_eval(std::ostream& stream); - bool save_eval(const std::optional& filename); + bool load_eval(std::string name, std::istream& stream); + bool save_eval(std::ostream& stream); + bool save_eval(const std::optional& filename); } // namespace Stockfish::Eval::NNUE -#endif // #ifndef NNUE_EVALUATE_NNUE_H_INCLUDED +#endif // #ifndef NNUE_EVALUATE_NNUE_H_INCLUDED diff --git a/src/nnue/features/half_ka_v2_hm.cpp b/src/nnue/features/half_ka_v2_hm.cpp index 016934b8c4d..d3aeec9914c 100644 --- a/src/nnue/features/half_ka_v2_hm.cpp +++ b/src/nnue/features/half_ka_v2_hm.cpp @@ -27,61 +27,52 @@ namespace Stockfish::Eval::NNUE::Features { - // Index of a feature for a given king position and another piece on some square - template - inline IndexType HalfKAv2_hm::make_index(Square s, Piece pc, Square ksq) { - return IndexType((int(s) ^ OrientTBL[Perspective][ksq]) + PieceSquareIndex[Perspective][pc] + KingBuckets[Perspective][ksq]); - } - - // Get a list of indices for active features - template - void HalfKAv2_hm::append_active_indices( - const Position& pos, - IndexList& active - ) { - Square ksq = pos.square(Perspective); - Bitboard bb = pos.pieces(); - while (bb) - { - Square s = pop_lsb(bb); - active.push_back(make_index(s, pos.piece_on(s), ksq)); + // Index of a feature for a given king position and another piece on some square + template + inline IndexType HalfKAv2_hm::make_index(Square s, Piece pc, Square ksq) { + return IndexType((int(s) ^ OrientTBL[Perspective][ksq]) + + PieceSquareIndex[Perspective][pc] + KingBuckets[Perspective][ksq]); } - } - - // Explicit template instantiations - template void HalfKAv2_hm::append_active_indices(const Position& pos, IndexList& active); - template void HalfKAv2_hm::append_active_indices(const Position& pos, IndexList& active); - - // append_changed_indices() : get a list of indices for recently changed features - template - void HalfKAv2_hm::append_changed_indices( - Square ksq, - const DirtyPiece& dp, - IndexList& removed, - IndexList& added - ) { - for (int i = 0; i < dp.dirty_num; ++i) { - if (dp.from[i] != SQ_NONE) - removed.push_back(make_index(dp.from[i], dp.piece[i], ksq)); - if (dp.to[i] != SQ_NONE) - added.push_back(make_index(dp.to[i], dp.piece[i], ksq)); + + // Get a list of indices for active features + template + void HalfKAv2_hm::append_active_indices(const Position& pos, IndexList& active) { + Square ksq = pos.square(Perspective); + Bitboard bb = pos.pieces(); + while (bb) { + Square s = pop_lsb(bb); + active.push_back(make_index(s, pos.piece_on(s), ksq)); + } + } + + // Explicit template instantiations + template void HalfKAv2_hm::append_active_indices(const Position& pos, IndexList& active); + template void HalfKAv2_hm::append_active_indices(const Position& pos, IndexList& active); + + // append_changed_indices() : get a list of indices for recently changed features + template + void HalfKAv2_hm::append_changed_indices(Square ksq, const DirtyPiece& dp, IndexList& removed, + IndexList& added) { + for (int i = 0; i < dp.dirty_num; ++i) { + if (dp.from[i] != SQ_NONE) + removed.push_back(make_index(dp.from[i], dp.piece[i], ksq)); + if (dp.to[i] != SQ_NONE) + added.push_back(make_index(dp.to[i], dp.piece[i], ksq)); + } } - } - // Explicit template instantiations - template void HalfKAv2_hm::append_changed_indices(Square ksq, const DirtyPiece& dp, IndexList& removed, IndexList& added); - template void HalfKAv2_hm::append_changed_indices(Square ksq, const DirtyPiece& dp, IndexList& removed, IndexList& added); + // Explicit template instantiations + template void HalfKAv2_hm::append_changed_indices(Square ksq, const DirtyPiece& dp, + IndexList& removed, IndexList& added); + template void HalfKAv2_hm::append_changed_indices(Square ksq, const DirtyPiece& dp, + IndexList& removed, IndexList& added); - int HalfKAv2_hm::update_cost(const StateInfo* st) { - return st->dirtyPiece.dirty_num; - } + int HalfKAv2_hm::update_cost(const StateInfo* st) { return st->dirtyPiece.dirty_num; } - int HalfKAv2_hm::refresh_cost(const Position& pos) { - return pos.count(); - } + int HalfKAv2_hm::refresh_cost(const Position& pos) { return pos.count(); } - bool HalfKAv2_hm::requires_refresh(const StateInfo* st, Color perspective) { - return st->dirtyPiece.piece[0] == make_piece(perspective, KING); - } + bool HalfKAv2_hm::requires_refresh(const StateInfo* st, Color perspective) { + return st->dirtyPiece.piece[0] == make_piece(perspective, KING); + } } // namespace Stockfish::Eval::NNUE::Features diff --git a/src/nnue/features/half_ka_v2_hm.h b/src/nnue/features/half_ka_v2_hm.h index 9da1cc05531..eea1d3941a1 100644 --- a/src/nnue/features/half_ka_v2_hm.h +++ b/src/nnue/features/half_ka_v2_hm.h @@ -28,127 +28,109 @@ #include "../nnue_common.h" namespace Stockfish { - struct StateInfo; - class Position; + struct StateInfo; + class Position; } namespace Stockfish::Eval::NNUE::Features { - // Feature HalfKAv2_hm: Combination of the position of own king - // and the position of pieces. Position mirrored such that king always on e..h files. - class HalfKAv2_hm { - - // unique number for each piece type on each square - enum { - PS_NONE = 0, - PS_W_PAWN = 0, - PS_B_PAWN = 1 * SQUARE_NB, - PS_W_KNIGHT = 2 * SQUARE_NB, - PS_B_KNIGHT = 3 * SQUARE_NB, - PS_W_BISHOP = 4 * SQUARE_NB, - PS_B_BISHOP = 5 * SQUARE_NB, - PS_W_ROOK = 6 * SQUARE_NB, - PS_B_ROOK = 7 * SQUARE_NB, - PS_W_QUEEN = 8 * SQUARE_NB, - PS_B_QUEEN = 9 * SQUARE_NB, - PS_KING = 10 * SQUARE_NB, - PS_NB = 11 * SQUARE_NB - }; - - static constexpr IndexType PieceSquareIndex[COLOR_NB][PIECE_NB] = { - // convention: W - us, B - them - // viewed from other side, W and B are reversed - { PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE, - PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE }, - { PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE, - PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE } - }; - - // Index of a feature for a given king position and another piece on some square - template - static IndexType make_index(Square s, Piece pc, Square ksq); - - public: - // Feature name - static constexpr const char* Name = "HalfKAv2_hm(Friend)"; - - // Hash value embedded in the evaluation file - static constexpr std::uint32_t HashValue = 0x7f234cb8u; - - // Number of feature dimensions - static constexpr IndexType Dimensions = - static_cast(SQUARE_NB) * static_cast(PS_NB) / 2; + // Feature HalfKAv2_hm: Combination of the position of own king + // and the position of pieces. Position mirrored such that king always on e..h files. + class HalfKAv2_hm { + + // unique number for each piece type on each square + enum { + PS_NONE = 0, + PS_W_PAWN = 0, + PS_B_PAWN = 1 * SQUARE_NB, + PS_W_KNIGHT = 2 * SQUARE_NB, + PS_B_KNIGHT = 3 * SQUARE_NB, + PS_W_BISHOP = 4 * SQUARE_NB, + PS_B_BISHOP = 5 * SQUARE_NB, + PS_W_ROOK = 6 * SQUARE_NB, + PS_B_ROOK = 7 * SQUARE_NB, + PS_W_QUEEN = 8 * SQUARE_NB, + PS_B_QUEEN = 9 * SQUARE_NB, + PS_KING = 10 * SQUARE_NB, + PS_NB = 11 * SQUARE_NB + }; + + static constexpr IndexType PieceSquareIndex[COLOR_NB][PIECE_NB] = { + // convention: W - us, B - them + // viewed from other side, W and B are reversed + {PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE, + PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE}, + {PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE, + PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE}}; + + // Index of a feature for a given king position and another piece on some square + template static IndexType make_index(Square s, Piece pc, Square ksq); + + public: + // Feature name + static constexpr const char* Name = "HalfKAv2_hm(Friend)"; + + // Hash value embedded in the evaluation file + static constexpr std::uint32_t HashValue = 0x7f234cb8u; + + // Number of feature dimensions + static constexpr IndexType Dimensions = + static_cast(SQUARE_NB) * static_cast(PS_NB) / 2; #define B(v) (v * PS_NB) - static constexpr int KingBuckets[COLOR_NB][SQUARE_NB] = { - { B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28), - B(24), B(25), B(26), B(27), B(27), B(26), B(25), B(24), - B(20), B(21), B(22), B(23), B(23), B(22), B(21), B(20), - B(16), B(17), B(18), B(19), B(19), B(18), B(17), B(16), - B(12), B(13), B(14), B(15), B(15), B(14), B(13), B(12), - B( 8), B( 9), B(10), B(11), B(11), B(10), B( 9), B( 8), - B( 4), B( 5), B( 6), B( 7), B( 7), B( 6), B( 5), B( 4), - B( 0), B( 1), B( 2), B( 3), B( 3), B( 2), B( 1), B( 0) }, - { B( 0), B( 1), B( 2), B( 3), B( 3), B( 2), B( 1), B( 0), - B( 4), B( 5), B( 6), B( 7), B( 7), B( 6), B( 5), B( 4), - B( 8), B( 9), B(10), B(11), B(11), B(10), B( 9), B( 8), - B(12), B(13), B(14), B(15), B(15), B(14), B(13), B(12), - B(16), B(17), B(18), B(19), B(19), B(18), B(17), B(16), - B(20), B(21), B(22), B(23), B(23), B(22), B(21), B(20), - B(24), B(25), B(26), B(27), B(27), B(26), B(25), B(24), - B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28) } - }; + static constexpr int KingBuckets[COLOR_NB][SQUARE_NB] = { + {B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28), B(24), B(25), B(26), + B(27), B(27), B(26), B(25), B(24), B(20), B(21), B(22), B(23), B(23), B(22), + B(21), B(20), B(16), B(17), B(18), B(19), B(19), B(18), B(17), B(16), B(12), + B(13), B(14), B(15), B(15), B(14), B(13), B(12), B(8), B(9), B(10), B(11), + B(11), B(10), B(9), B(8), B(4), B(5), B(6), B(7), B(7), B(6), B(5), + B(4), B(0), B(1), B(2), B(3), B(3), B(2), B(1), B(0)}, + {B(0), B(1), B(2), B(3), B(3), B(2), B(1), B(0), B(4), B(5), B(6), + B(7), B(7), B(6), B(5), B(4), B(8), B(9), B(10), B(11), B(11), B(10), + B(9), B(8), B(12), B(13), B(14), B(15), B(15), B(14), B(13), B(12), B(16), + B(17), B(18), B(19), B(19), B(18), B(17), B(16), B(20), B(21), B(22), B(23), + B(23), B(22), B(21), B(20), B(24), B(25), B(26), B(27), B(27), B(26), B(25), + B(24), B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28)}}; #undef B - // Orient a square according to perspective (rotates by 180 for black) - static constexpr int OrientTBL[COLOR_NB][SQUARE_NB] = { - { SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, - SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, - SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, - SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, - SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, - SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, - SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, - SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1 }, - { SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, - SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, - SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, - SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, - SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, - SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, - SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, - SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8 } + // Orient a square according to perspective (rotates by 180 for black) + static constexpr int OrientTBL[COLOR_NB][SQUARE_NB] = { + {SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, SQ_H1, SQ_H1, SQ_H1, + SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, + SQ_A1, SQ_A1, SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, SQ_H1, + SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, SQ_H1, SQ_H1, SQ_H1, SQ_H1, + SQ_A1, SQ_A1, SQ_A1, SQ_A1, SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, + SQ_A1, SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1}, + {SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, SQ_H8, SQ_H8, SQ_H8, + SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, + SQ_A8, SQ_A8, SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, SQ_H8, + SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, SQ_H8, SQ_H8, SQ_H8, SQ_H8, + SQ_A8, SQ_A8, SQ_A8, SQ_A8, SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, + SQ_A8, SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8}}; + + // Maximum number of simultaneously active features. + static constexpr IndexType MaxActiveDimensions = 32; + using IndexList = ValueList; + + // Get a list of indices for active features + template + static void append_active_indices(const Position& pos, IndexList& active); + + // Get a list of indices for recently changed features + template + static void append_changed_indices(Square ksq, const DirtyPiece& dp, IndexList& removed, + IndexList& added); + + // Returns the cost of updating one perspective, the most costly one. + // Assumes no refresh needed. + static int update_cost(const StateInfo* st); + static int refresh_cost(const Position& pos); + + // Returns whether the change stored in this StateInfo means that + // a full accumulator refresh is required. + static bool requires_refresh(const StateInfo* st, Color perspective); }; - // Maximum number of simultaneously active features. - static constexpr IndexType MaxActiveDimensions = 32; - using IndexList = ValueList; - - // Get a list of indices for active features - template - static void append_active_indices( - const Position& pos, - IndexList& active); - - // Get a list of indices for recently changed features - template - static void append_changed_indices( - Square ksq, - const DirtyPiece& dp, - IndexList& removed, - IndexList& added - ); - - // Returns the cost of updating one perspective, the most costly one. - // Assumes no refresh needed. - static int update_cost(const StateInfo* st); - static int refresh_cost(const Position& pos); - - // Returns whether the change stored in this StateInfo means that - // a full accumulator refresh is required. - static bool requires_refresh(const StateInfo* st, Color perspective); - }; - } // namespace Stockfish::Eval::NNUE::Features -#endif // #ifndef NNUE_FEATURES_HALF_KA_V2_HM_H_INCLUDED +#endif // #ifndef NNUE_FEATURES_HALF_KA_V2_HM_H_INCLUDED diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h index 61cdb781866..6113cece2a5 100644 --- a/src/nnue/layers/affine_transform.h +++ b/src/nnue/layers/affine_transform.h @@ -42,297 +42,283 @@ namespace Stockfish::Eval::NNUE::Layers { // Fallback implementation for older/other architectures. // Requires the input to be padded to at least 16 values. #if !defined(USE_SSSE3) - template - static void affine_transform_non_ssse3(std::int32_t* output, const std::int8_t* weights, const std::int32_t* biases, const std::uint8_t* input) - { -# if defined(USE_SSE2) - // At least a multiple of 16, with SSE2. - constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 16) / 16; - const __m128i Zeros = _mm_setzero_si128(); - const auto inputVector = reinterpret_cast(input); - -# elif defined(USE_MMX) - constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 8) / 8; - const __m64 Zeros = _mm_setzero_si64(); - const auto inputVector = reinterpret_cast(input); - -# elif defined(USE_NEON_DOTPROD) - constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 16) / 16; - const auto inputVector = reinterpret_cast(input); - -# elif defined(USE_NEON) - constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 16) / 16; - const auto inputVector = reinterpret_cast(input); -# endif - - for (IndexType i = 0; i < OutputDimensions; ++i) { - const IndexType offset = i * PaddedInputDimensions; - -# if defined(USE_SSE2) - __m128i sumLo = _mm_cvtsi32_si128(biases[i]); - __m128i sumHi = Zeros; - const auto row = reinterpret_cast(&weights[offset]); - for (IndexType j = 0; j < NumChunks; ++j) { - __m128i row_j = _mm_load_si128(&row[j]); - __m128i input_j = _mm_load_si128(&inputVector[j]); - __m128i extendedRowLo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8); - __m128i extendedRowHi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8); - __m128i extendedInputLo = _mm_unpacklo_epi8(input_j, Zeros); - __m128i extendedInputHi = _mm_unpackhi_epi8(input_j, Zeros); - __m128i productLo = _mm_madd_epi16(extendedRowLo, extendedInputLo); - __m128i productHi = _mm_madd_epi16(extendedRowHi, extendedInputHi); - sumLo = _mm_add_epi32(sumLo, productLo); - sumHi = _mm_add_epi32(sumHi, productHi); - } - __m128i sum = _mm_add_epi32(sumLo, sumHi); - __m128i sumHigh_64 = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2)); - sum = _mm_add_epi32(sum, sumHigh_64); - __m128i sum_second_32 = _mm_shufflelo_epi16(sum, _MM_SHUFFLE(1, 0, 3, 2)); - sum = _mm_add_epi32(sum, sum_second_32); - output[i] = _mm_cvtsi128_si32(sum); - -# elif defined(USE_MMX) - __m64 sumLo = _mm_cvtsi32_si64(biases[i]); - __m64 sumHi = Zeros; - const auto row = reinterpret_cast(&weights[offset]); - for (IndexType j = 0; j < NumChunks; ++j) { - __m64 row_j = row[j]; - __m64 input_j = inputVector[j]; - __m64 extendedRowLo = _mm_srai_pi16(_mm_unpacklo_pi8(row_j, row_j), 8); - __m64 extendedRowHi = _mm_srai_pi16(_mm_unpackhi_pi8(row_j, row_j), 8); - __m64 extendedInputLo = _mm_unpacklo_pi8(input_j, Zeros); - __m64 extendedInputHi = _mm_unpackhi_pi8(input_j, Zeros); - __m64 productLo = _mm_madd_pi16(extendedRowLo, extendedInputLo); - __m64 productHi = _mm_madd_pi16(extendedRowHi, extendedInputHi); - sumLo = _mm_add_pi32(sumLo, productLo); - sumHi = _mm_add_pi32(sumHi, productHi); - } - __m64 sum = _mm_add_pi32(sumLo, sumHi); - sum = _mm_add_pi32(sum, _mm_unpackhi_pi32(sum, sum)); - output[i] = _mm_cvtsi64_si32(sum); - -# elif defined(USE_NEON_DOTPROD) - int32x4_t sum = {biases[i]}; - const auto row = reinterpret_cast(&weights[offset]); - for (IndexType j = 0; j < NumChunks; ++j) { - sum = vdotq_s32(sum, inputVector[j], row[j]); - } - output[i] = vaddvq_s32(sum); - -# elif defined(USE_NEON) - int32x4_t sum = {biases[i]}; - const auto row = reinterpret_cast(&weights[offset]); - for (IndexType j = 0; j < NumChunks; ++j) { - int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]); - product = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]); - sum = vpadalq_s16(sum, product); - } - output[i] = sum[0] + sum[1] + sum[2] + sum[3]; - -# else - std::int32_t sum = biases[i]; - for (IndexType j = 0; j < InputDimensions; ++j) { - sum += weights[offset + j] * input[j]; - } - output[i] = sum; -# endif - } - -# if defined(USE_MMX) - _mm_empty(); -# endif - } -#endif - - template - class AffineTransform { - public: - // Input/output type - using InputType = std::uint8_t; - using OutputType = std::int32_t; - - // Number of input/output dimensions - static constexpr IndexType InputDimensions = InDims; - static constexpr IndexType OutputDimensions = OutDims; - - static constexpr IndexType PaddedInputDimensions = - ceil_to_multiple(InputDimensions, MaxSimdWidth); - static constexpr IndexType PaddedOutputDimensions = - ceil_to_multiple(OutputDimensions, MaxSimdWidth); - - using OutputBuffer = OutputType[PaddedOutputDimensions]; - - // Hash value embedded in the evaluation file - static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { - std::uint32_t hashValue = 0xCC03DAE4u; - hashValue += OutputDimensions; - hashValue ^= prevHash >> 1; - hashValue ^= prevHash << 31; - return hashValue; - } - - static constexpr IndexType get_weight_index_scrambled(IndexType i) - { - return - (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 + - i / PaddedInputDimensions * 4 + - i % 4; - } - - static constexpr IndexType get_weight_index(IndexType i) - { -#if defined (USE_SSSE3) - return get_weight_index_scrambled(i); -#else - return i; -#endif - } - - // Read network parameters - bool read_parameters(std::istream& stream) { - read_little_endian(stream, biases, OutputDimensions); - for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) - weights[get_weight_index(i)] = read_little_endian(stream); - - return !stream.fail(); - } - - // Write network parameters - bool write_parameters(std::ostream& stream) const { - write_little_endian(stream, biases, OutputDimensions); - - for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) - write_little_endian(stream, weights[get_weight_index(i)]); + template + static void affine_transform_non_ssse3(std::int32_t* output, const std::int8_t* weights, + const std::int32_t* biases, const std::uint8_t* input) { + #if defined(USE_SSE2) + // At least a multiple of 16, with SSE2. + constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 16) / 16; + const __m128i Zeros = _mm_setzero_si128(); + const auto inputVector = reinterpret_cast(input); + + #elif defined(USE_MMX) + constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 8) / 8; + const __m64 Zeros = _mm_setzero_si64(); + const auto inputVector = reinterpret_cast(input); + + #elif defined(USE_NEON_DOTPROD) + constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 16) / 16; + const auto inputVector = reinterpret_cast(input); + + #elif defined(USE_NEON) + constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 16) / 16; + const auto inputVector = reinterpret_cast(input); + #endif + + for (IndexType i = 0; i < OutputDimensions; ++i) { + const IndexType offset = i * PaddedInputDimensions; + + #if defined(USE_SSE2) + __m128i sumLo = _mm_cvtsi32_si128(biases[i]); + __m128i sumHi = Zeros; + const auto row = reinterpret_cast(&weights[offset]); + for (IndexType j = 0; j < NumChunks; ++j) { + __m128i row_j = _mm_load_si128(&row[j]); + __m128i input_j = _mm_load_si128(&inputVector[j]); + __m128i extendedRowLo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8); + __m128i extendedRowHi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8); + __m128i extendedInputLo = _mm_unpacklo_epi8(input_j, Zeros); + __m128i extendedInputHi = _mm_unpackhi_epi8(input_j, Zeros); + __m128i productLo = _mm_madd_epi16(extendedRowLo, extendedInputLo); + __m128i productHi = _mm_madd_epi16(extendedRowHi, extendedInputHi); + sumLo = _mm_add_epi32(sumLo, productLo); + sumHi = _mm_add_epi32(sumHi, productHi); + } + __m128i sum = _mm_add_epi32(sumLo, sumHi); + __m128i sumHigh_64 = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2)); + sum = _mm_add_epi32(sum, sumHigh_64); + __m128i sum_second_32 = _mm_shufflelo_epi16(sum, _MM_SHUFFLE(1, 0, 3, 2)); + sum = _mm_add_epi32(sum, sum_second_32); + output[i] = _mm_cvtsi128_si32(sum); + + #elif defined(USE_MMX) + __m64 sumLo = _mm_cvtsi32_si64(biases[i]); + __m64 sumHi = Zeros; + const auto row = reinterpret_cast(&weights[offset]); + for (IndexType j = 0; j < NumChunks; ++j) { + __m64 row_j = row[j]; + __m64 input_j = inputVector[j]; + __m64 extendedRowLo = _mm_srai_pi16(_mm_unpacklo_pi8(row_j, row_j), 8); + __m64 extendedRowHi = _mm_srai_pi16(_mm_unpackhi_pi8(row_j, row_j), 8); + __m64 extendedInputLo = _mm_unpacklo_pi8(input_j, Zeros); + __m64 extendedInputHi = _mm_unpackhi_pi8(input_j, Zeros); + __m64 productLo = _mm_madd_pi16(extendedRowLo, extendedInputLo); + __m64 productHi = _mm_madd_pi16(extendedRowHi, extendedInputHi); + sumLo = _mm_add_pi32(sumLo, productLo); + sumHi = _mm_add_pi32(sumHi, productHi); + } + __m64 sum = _mm_add_pi32(sumLo, sumHi); + sum = _mm_add_pi32(sum, _mm_unpackhi_pi32(sum, sum)); + output[i] = _mm_cvtsi64_si32(sum); + + #elif defined(USE_NEON_DOTPROD) + int32x4_t sum = {biases[i]}; + const auto row = reinterpret_cast(&weights[offset]); + for (IndexType j = 0; j < NumChunks; ++j) { + sum = vdotq_s32(sum, inputVector[j], row[j]); + } + output[i] = vaddvq_s32(sum); + + #elif defined(USE_NEON) + int32x4_t sum = {biases[i]}; + const auto row = reinterpret_cast(&weights[offset]); + for (IndexType j = 0; j < NumChunks; ++j) { + int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]); + product = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]); + sum = vpadalq_s16(sum, product); + } + output[i] = sum[0] + sum[1] + sum[2] + sum[3]; + + #else + std::int32_t sum = biases[i]; + for (IndexType j = 0; j < InputDimensions; ++j) { + sum += weights[offset + j] * input[j]; + } + output[i] = sum; + #endif + } - return !stream.fail(); + #if defined(USE_MMX) + _mm_empty(); + #endif } - // Forward propagation - void propagate( - const InputType* input, OutputType* output) const { - -#if defined (USE_SSSE3) - - if constexpr (OutputDimensions > 1) - { - -#if defined (USE_AVX512) - using vec_t = __m512i; - #define vec_setzero _mm512_setzero_si512 - #define vec_set_32 _mm512_set1_epi32 - #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32 - #define vec_add_dpbusd_32x2 Simd::m512_add_dpbusd_epi32x2 - #define vec_hadd Simd::m512_hadd -#elif defined (USE_AVX2) - using vec_t = __m256i; - #define vec_setzero _mm256_setzero_si256 - #define vec_set_32 _mm256_set1_epi32 - #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32 - #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2 - #define vec_hadd Simd::m256_hadd -#elif defined (USE_SSSE3) - using vec_t = __m128i; - #define vec_setzero _mm_setzero_si128 - #define vec_set_32 _mm_set1_epi32 - #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32 - #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2 - #define vec_hadd Simd::m128_hadd #endif - static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType); - - static_assert(OutputDimensions % OutputSimdWidth == 0); - - constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 8) / 4; - constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth; - - const auto input32 = reinterpret_cast(input); - const vec_t* biasvec = reinterpret_cast(biases); - vec_t acc[NumRegs]; - for (IndexType k = 0; k < NumRegs; ++k) - acc[k] = biasvec[k]; + template class AffineTransform { + public: + // Input/output type + using InputType = std::uint8_t; + using OutputType = std::int32_t; + + // Number of input/output dimensions + static constexpr IndexType InputDimensions = InDims; + static constexpr IndexType OutputDimensions = OutDims; + + static constexpr IndexType PaddedInputDimensions = + ceil_to_multiple(InputDimensions, MaxSimdWidth); + static constexpr IndexType PaddedOutputDimensions = + ceil_to_multiple(OutputDimensions, MaxSimdWidth); + + using OutputBuffer = OutputType[PaddedOutputDimensions]; + + // Hash value embedded in the evaluation file + static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { + std::uint32_t hashValue = 0xCC03DAE4u; + hashValue += OutputDimensions; + hashValue ^= prevHash >> 1; + hashValue ^= prevHash << 31; + return hashValue; + } - for (IndexType i = 0; i < NumChunks; i += 2) - { - const vec_t in0 = vec_set_32(input32[i + 0]); - const vec_t in1 = vec_set_32(input32[i + 1]); - const auto col0 = reinterpret_cast(&weights[(i + 0) * OutputDimensions * 4]); - const auto col1 = reinterpret_cast(&weights[(i + 1) * OutputDimensions * 4]); - for (IndexType k = 0; k < NumRegs; ++k) - vec_add_dpbusd_32x2(acc[k], in0, col0[k], in1, col1[k]); + static constexpr IndexType get_weight_index_scrambled(IndexType i) { + return (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 + + i / PaddedInputDimensions * 4 + i % 4; } - vec_t* outptr = reinterpret_cast(output); - for (IndexType k = 0; k < NumRegs; ++k) - outptr[k] = acc[k]; - -# undef vec_setzero -# undef vec_set_32 -# undef vec_add_dpbusd_32 -# undef vec_add_dpbusd_32x2 -# undef vec_hadd - - } - else if constexpr (OutputDimensions == 1) - { - -// We cannot use AVX512 for the last layer because there's only 32 inputs and the buffer is not padded to 64 elements. -#if defined (USE_AVX2) - using vec_t = __m256i; - #define vec_setzero _mm256_setzero_si256 - #define vec_set_32 _mm256_set1_epi32 - #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32 - #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2 - #define vec_hadd Simd::m256_hadd -#elif defined (USE_SSSE3) - using vec_t = __m128i; - #define vec_setzero _mm_setzero_si128 - #define vec_set_32 _mm_set1_epi32 - #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32 - #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2 - #define vec_hadd Simd::m128_hadd + static constexpr IndexType get_weight_index(IndexType i) { +#if defined(USE_SSSE3) + return get_weight_index_scrambled(i); +#else + return i; #endif + } - const auto inputVector = reinterpret_cast(input); + // Read network parameters + bool read_parameters(std::istream& stream) { + read_little_endian(stream, biases, OutputDimensions); + for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) + weights[get_weight_index(i)] = read_little_endian(stream); - static constexpr IndexType InputSimdWidth = sizeof(vec_t) / sizeof(InputType); + return !stream.fail(); + } - static_assert(PaddedInputDimensions % InputSimdWidth == 0); + // Write network parameters + bool write_parameters(std::ostream& stream) const { + write_little_endian(stream, biases, OutputDimensions); - constexpr IndexType NumChunks = PaddedInputDimensions / InputSimdWidth; - vec_t sum0 = vec_setzero(); - const auto row0 = reinterpret_cast(&weights[0]); + for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) + write_little_endian(stream, weights[get_weight_index(i)]); - for (int j = 0; j < (int)NumChunks; ++j) - { - const vec_t in = inputVector[j]; - vec_add_dpbusd_32(sum0, in, row0[j]); + return !stream.fail(); } - output[0] = vec_hadd(sum0, biases[0]); - -# undef vec_setzero -# undef vec_set_32 -# undef vec_add_dpbusd_32 -# undef vec_add_dpbusd_32x2 -# undef vec_hadd - - } + // Forward propagation + void propagate(const InputType* input, OutputType* output) const { + +#if defined(USE_SSSE3) + + if constexpr (OutputDimensions > 1) { + + #if defined(USE_AVX512) + using vec_t = __m512i; + #define vec_setzero _mm512_setzero_si512 + #define vec_set_32 _mm512_set1_epi32 + #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32 + #define vec_add_dpbusd_32x2 Simd::m512_add_dpbusd_epi32x2 + #define vec_hadd Simd::m512_hadd + #elif defined(USE_AVX2) + using vec_t = __m256i; + #define vec_setzero _mm256_setzero_si256 + #define vec_set_32 _mm256_set1_epi32 + #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32 + #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2 + #define vec_hadd Simd::m256_hadd + #elif defined(USE_SSSE3) + using vec_t = __m128i; + #define vec_setzero _mm_setzero_si128 + #define vec_set_32 _mm_set1_epi32 + #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32 + #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2 + #define vec_hadd Simd::m128_hadd + #endif + + static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType); + + static_assert(OutputDimensions % OutputSimdWidth == 0); + + constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 8) / 4; + constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth; + + const auto input32 = reinterpret_cast(input); + const vec_t* biasvec = reinterpret_cast(biases); + vec_t acc[NumRegs]; + for (IndexType k = 0; k < NumRegs; ++k) acc[k] = biasvec[k]; + + for (IndexType i = 0; i < NumChunks; i += 2) { + const vec_t in0 = vec_set_32(input32[i + 0]); + const vec_t in1 = vec_set_32(input32[i + 1]); + const auto col0 = + reinterpret_cast(&weights[(i + 0) * OutputDimensions * 4]); + const auto col1 = + reinterpret_cast(&weights[(i + 1) * OutputDimensions * 4]); + for (IndexType k = 0; k < NumRegs; ++k) + vec_add_dpbusd_32x2(acc[k], in0, col0[k], in1, col1[k]); + } + + vec_t* outptr = reinterpret_cast(output); + for (IndexType k = 0; k < NumRegs; ++k) outptr[k] = acc[k]; + + #undef vec_setzero + #undef vec_set_32 + #undef vec_add_dpbusd_32 + #undef vec_add_dpbusd_32x2 + #undef vec_hadd + + } else if constexpr (OutputDimensions == 1) { + + // We cannot use AVX512 for the last layer because there's only 32 inputs and the buffer is not padded to 64 elements. + #if defined(USE_AVX2) + using vec_t = __m256i; + #define vec_setzero _mm256_setzero_si256 + #define vec_set_32 _mm256_set1_epi32 + #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32 + #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2 + #define vec_hadd Simd::m256_hadd + #elif defined(USE_SSSE3) + using vec_t = __m128i; + #define vec_setzero _mm_setzero_si128 + #define vec_set_32 _mm_set1_epi32 + #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32 + #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2 + #define vec_hadd Simd::m128_hadd + #endif + + const auto inputVector = reinterpret_cast(input); + + static constexpr IndexType InputSimdWidth = sizeof(vec_t) / sizeof(InputType); + + static_assert(PaddedInputDimensions % InputSimdWidth == 0); + + constexpr IndexType NumChunks = PaddedInputDimensions / InputSimdWidth; + vec_t sum0 = vec_setzero(); + const auto row0 = reinterpret_cast(&weights[0]); + + for (int j = 0; j < (int) NumChunks; ++j) { + const vec_t in = inputVector[j]; + vec_add_dpbusd_32(sum0, in, row0[j]); + } + output[0] = vec_hadd(sum0, biases[0]); + + #undef vec_setzero + #undef vec_set_32 + #undef vec_add_dpbusd_32 + #undef vec_add_dpbusd_32x2 + #undef vec_hadd + } #else - // Use old implementation for the other architectures. - affine_transform_non_ssse3< - InputDimensions, - PaddedInputDimensions, - OutputDimensions>(output, weights, biases, input); + // Use old implementation for the other architectures. + affine_transform_non_ssse3( + output, weights, biases, input); #endif - } + } - private: - using BiasType = OutputType; - using WeightType = std::int8_t; + private: + using BiasType = OutputType; + using WeightType = std::int8_t; - alignas(CacheLineSize) BiasType biases[OutputDimensions]; - alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions]; - }; + alignas(CacheLineSize) BiasType biases[OutputDimensions]; + alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions]; + }; } // namespace Stockfish::Eval::NNUE::Layers -#endif // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED +#endif // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED diff --git a/src/nnue/layers/affine_transform_sparse_input.h b/src/nnue/layers/affine_transform_sparse_input.h index c9894f5d96e..b2ef218e358 100644 --- a/src/nnue/layers/affine_transform_sparse_input.h +++ b/src/nnue/layers/affine_transform_sparse_input.h @@ -38,241 +38,233 @@ namespace Stockfish::Eval::NNUE::Layers { #if (USE_SSSE3 | (USE_NEON >= 8)) - alignas(CacheLineSize) static inline const std::array, 256> lookup_indices = [](){ - std::array, 256> v{}; - for (unsigned i = 0; i < 256; ++i) - { - std::uint64_t j = i, k = 0; - while(j) - v[i][k++] = pop_lsb(j); - } - return v; - }(); - - // Find indices of nonzero numbers in an int32_t array - template - void find_nnz(const std::int32_t* input, std::uint16_t* out, IndexType& count_out) { -#if defined (USE_SSSE3) - #if defined (USE_AVX512) + alignas(CacheLineSize) static inline const + std::array, 256> lookup_indices = []() { + std::array, 256> v{}; + for (unsigned i = 0; i < 256; ++i) { + std::uint64_t j = i, k = 0; + while (j) v[i][k++] = pop_lsb(j); + } + return v; + }(); + + // Find indices of nonzero numbers in an int32_t array + template + void find_nnz(const std::int32_t* input, std::uint16_t* out, IndexType& count_out) { + #if defined(USE_SSSE3) + #if defined(USE_AVX512) using vec_t = __m512i; - #define vec_nnz(a) _mm512_cmpgt_epi32_mask(a, _mm512_setzero_si512()) - #elif defined (USE_AVX2) + #define vec_nnz(a) _mm512_cmpgt_epi32_mask(a, _mm512_setzero_si512()) + #elif defined(USE_AVX2) using vec_t = __m256i; - #if defined(USE_VNNI) && !defined(USE_AVXVNNI) - #define vec_nnz(a) _mm256_cmpgt_epi32_mask(a, _mm256_setzero_si256()) - #else - #define vec_nnz(a) _mm256_movemask_ps(_mm256_castsi256_ps(_mm256_cmpgt_epi32(a, _mm256_setzero_si256()))) - #endif - #elif defined (USE_SSSE3) + #if defined(USE_VNNI) && !defined(USE_AVXVNNI) + #define vec_nnz(a) _mm256_cmpgt_epi32_mask(a, _mm256_setzero_si256()) + #else + #define vec_nnz(a) \ + _mm256_movemask_ps( \ + _mm256_castsi256_ps(_mm256_cmpgt_epi32(a, _mm256_setzero_si256()))) + #endif + #elif defined(USE_SSSE3) using vec_t = __m128i; - #define vec_nnz(a) _mm_movemask_ps(_mm_castsi128_ps(_mm_cmpgt_epi32(a, _mm_setzero_si128()))) + #define vec_nnz(a) \ + _mm_movemask_ps(_mm_castsi128_ps(_mm_cmpgt_epi32(a, _mm_setzero_si128()))) + #endif + using vec128_t = __m128i; + #define vec128_zero _mm_setzero_si128() + #define vec128_set_16(a) _mm_set1_epi16(a) + #define vec128_load(a) _mm_load_si128(a) + #define vec128_storeu(a, b) _mm_storeu_si128(a, b) + #define vec128_add(a, b) _mm_add_epi16(a, b) + #elif defined(USE_NEON) + using vec_t = uint32x4_t; + static const std::uint32_t Mask[4] = {1, 2, 4, 8}; + #define vec_nnz(a) vaddvq_u32(vandq_u32(vtstq_u32(a, a), vld1q_u32(Mask))) + using vec128_t = uint16x8_t; + #define vec128_zero vdupq_n_u16(0) + #define vec128_set_16(a) vdupq_n_u16(a) + #define vec128_load(a) vld1q_u16(reinterpret_cast(a)) + #define vec128_storeu(a, b) vst1q_u16(reinterpret_cast(a), b) + #define vec128_add(a, b) vaddq_u16(a, b) #endif - using vec128_t = __m128i; - #define vec128_zero _mm_setzero_si128() - #define vec128_set_16(a) _mm_set1_epi16(a) - #define vec128_load(a) _mm_load_si128(a) - #define vec128_storeu(a, b) _mm_storeu_si128(a, b) - #define vec128_add(a, b) _mm_add_epi16(a, b) -#elif defined (USE_NEON) - using vec_t = uint32x4_t; - static const std::uint32_t Mask[4] = {1, 2, 4, 8}; - #define vec_nnz(a) vaddvq_u32(vandq_u32(vtstq_u32(a, a), vld1q_u32(Mask))) - using vec128_t = uint16x8_t; - #define vec128_zero vdupq_n_u16(0) - #define vec128_set_16(a) vdupq_n_u16(a) - #define vec128_load(a) vld1q_u16(reinterpret_cast(a)) - #define vec128_storeu(a, b) vst1q_u16(reinterpret_cast(a), b) - #define vec128_add(a, b) vaddq_u16(a, b) -#endif - constexpr IndexType InputSimdWidth = sizeof(vec_t) / sizeof(std::int32_t); - // Inputs are processed InputSimdWidth at a time and outputs are processed 8 at a time so we process in chunks of max(InputSimdWidth, 8) - constexpr IndexType ChunkSize = std::max(InputSimdWidth, 8); - constexpr IndexType NumChunks = InputDimensions / ChunkSize; - constexpr IndexType InputsPerChunk = ChunkSize / InputSimdWidth; - constexpr IndexType OutputsPerChunk = ChunkSize / 8; - - const auto inputVector = reinterpret_cast(input); - IndexType count = 0; - vec128_t base = vec128_zero; - const vec128_t increment = vec128_set_16(8); - for (IndexType i = 0; i < NumChunks; ++i) - { - // bitmask of nonzero values in this chunk - unsigned nnz = 0; - for (IndexType j = 0; j < InputsPerChunk; ++j) - { - const vec_t inputChunk = inputVector[i * InputsPerChunk + j]; - nnz |= (unsigned)vec_nnz(inputChunk) << (j * InputSimdWidth); - } - for (IndexType j = 0; j < OutputsPerChunk; ++j) - { - const auto lookup = (nnz >> (j * 8)) & 0xFF; - const auto offsets = vec128_load(reinterpret_cast(&lookup_indices[lookup])); - vec128_storeu(reinterpret_cast(out + count), vec128_add(base, offsets)); - count += popcount(lookup); - base = vec128_add(base, increment); - } + constexpr IndexType InputSimdWidth = sizeof(vec_t) / sizeof(std::int32_t); + // Inputs are processed InputSimdWidth at a time and outputs are processed 8 at a time so we process in chunks of max(InputSimdWidth, 8) + constexpr IndexType ChunkSize = std::max(InputSimdWidth, 8); + constexpr IndexType NumChunks = InputDimensions / ChunkSize; + constexpr IndexType InputsPerChunk = ChunkSize / InputSimdWidth; + constexpr IndexType OutputsPerChunk = ChunkSize / 8; + + const auto inputVector = reinterpret_cast(input); + IndexType count = 0; + vec128_t base = vec128_zero; + const vec128_t increment = vec128_set_16(8); + for (IndexType i = 0; i < NumChunks; ++i) { + // bitmask of nonzero values in this chunk + unsigned nnz = 0; + for (IndexType j = 0; j < InputsPerChunk; ++j) { + const vec_t inputChunk = inputVector[i * InputsPerChunk + j]; + nnz |= (unsigned) vec_nnz(inputChunk) << (j * InputSimdWidth); + } + for (IndexType j = 0; j < OutputsPerChunk; ++j) { + const auto lookup = (nnz >> (j * 8)) & 0xFF; + const auto offsets = + vec128_load(reinterpret_cast(&lookup_indices[lookup])); + vec128_storeu(reinterpret_cast(out + count), vec128_add(base, offsets)); + count += popcount(lookup); + base = vec128_add(base, increment); + } + } + count_out = count; } - count_out = count; - } -# undef vec_nnz -# undef vec128_zero -# undef vec128_set_16 -# undef vec128_load -# undef vec128_storeu -# undef vec128_add + #undef vec_nnz + #undef vec128_zero + #undef vec128_set_16 + #undef vec128_load + #undef vec128_storeu + #undef vec128_add #endif - // Sparse input implementation - template - class AffineTransformSparseInput { - public: - // Input/output type - using InputType = std::uint8_t; - using OutputType = std::int32_t; + // Sparse input implementation + template class AffineTransformSparseInput { + public: + // Input/output type + using InputType = std::uint8_t; + using OutputType = std::int32_t; - // Number of input/output dimensions - static constexpr IndexType InputDimensions = InDims; - static constexpr IndexType OutputDimensions = OutDims; + // Number of input/output dimensions + static constexpr IndexType InputDimensions = InDims; + static constexpr IndexType OutputDimensions = OutDims; - static_assert(OutputDimensions % 16 == 0, "Only implemented for OutputDimensions divisible by 16."); + static_assert(OutputDimensions % 16 == 0, + "Only implemented for OutputDimensions divisible by 16."); - static constexpr IndexType PaddedInputDimensions = - ceil_to_multiple(InputDimensions, MaxSimdWidth); - static constexpr IndexType PaddedOutputDimensions = - ceil_to_multiple(OutputDimensions, MaxSimdWidth); + static constexpr IndexType PaddedInputDimensions = + ceil_to_multiple(InputDimensions, MaxSimdWidth); + static constexpr IndexType PaddedOutputDimensions = + ceil_to_multiple(OutputDimensions, MaxSimdWidth); #if (USE_SSSE3 | (USE_NEON >= 8)) - static constexpr IndexType ChunkSize = 4; + static constexpr IndexType ChunkSize = 4; #else - static constexpr IndexType ChunkSize = 1; + static constexpr IndexType ChunkSize = 1; #endif - using OutputBuffer = OutputType[PaddedOutputDimensions]; + using OutputBuffer = OutputType[PaddedOutputDimensions]; - // Hash value embedded in the evaluation file - static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { - std::uint32_t hashValue = 0xCC03DAE4u; - hashValue += OutputDimensions; - hashValue ^= prevHash >> 1; - hashValue ^= prevHash << 31; - return hashValue; - } + // Hash value embedded in the evaluation file + static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { + std::uint32_t hashValue = 0xCC03DAE4u; + hashValue += OutputDimensions; + hashValue ^= prevHash >> 1; + hashValue ^= prevHash << 31; + return hashValue; + } - static constexpr IndexType get_weight_index_scrambled(IndexType i) - { - return - (i / ChunkSize) % (PaddedInputDimensions / ChunkSize) * OutputDimensions * ChunkSize + - i / PaddedInputDimensions * ChunkSize + - i % ChunkSize; - } + static constexpr IndexType get_weight_index_scrambled(IndexType i) { + return (i / ChunkSize) % (PaddedInputDimensions / ChunkSize) * OutputDimensions * + ChunkSize + + i / PaddedInputDimensions * ChunkSize + i % ChunkSize; + } - static constexpr IndexType get_weight_index(IndexType i) - { + static constexpr IndexType get_weight_index(IndexType i) { #if (USE_SSSE3 | (USE_NEON >= 8)) - return get_weight_index_scrambled(i); + return get_weight_index_scrambled(i); #else - return i; + return i; #endif - } + } - // Read network parameters - bool read_parameters(std::istream& stream) { - read_little_endian(stream, biases, OutputDimensions); - for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) - weights[get_weight_index(i)] = read_little_endian(stream); + // Read network parameters + bool read_parameters(std::istream& stream) { + read_little_endian(stream, biases, OutputDimensions); + for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) + weights[get_weight_index(i)] = read_little_endian(stream); - return !stream.fail(); - } + return !stream.fail(); + } - // Write network parameters - bool write_parameters(std::ostream& stream) const { - write_little_endian(stream, biases, OutputDimensions); + // Write network parameters + bool write_parameters(std::ostream& stream) const { + write_little_endian(stream, biases, OutputDimensions); - for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) - write_little_endian(stream, weights[get_weight_index(i)]); + for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) + write_little_endian(stream, weights[get_weight_index(i)]); - return !stream.fail(); - } - // Forward propagation - void propagate( - const InputType* input, OutputType* output) const { + return !stream.fail(); + } + // Forward propagation + void propagate(const InputType* input, OutputType* output) const { #if (USE_SSSE3 | (USE_NEON >= 8)) -#if defined (USE_AVX512) - using invec_t = __m512i; - using outvec_t = __m512i; - #define vec_set_32 _mm512_set1_epi32 - #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32 -#elif defined (USE_AVX2) - using invec_t = __m256i; - using outvec_t = __m256i; - #define vec_set_32 _mm256_set1_epi32 - #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32 -#elif defined (USE_SSSE3) - using invec_t = __m128i; - using outvec_t = __m128i; - #define vec_set_32 _mm_set1_epi32 - #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32 -#elif defined (USE_NEON_DOTPROD) - using invec_t = int8x16_t; - using outvec_t = int32x4_t; - #define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a)) - #define vec_add_dpbusd_32 Simd::dotprod_m128_add_dpbusd_epi32 -#elif defined (USE_NEON) - using invec_t = int8x16_t; - using outvec_t = int32x4_t; - #define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a)) - #define vec_add_dpbusd_32 Simd::neon_m128_add_dpbusd_epi32 -#endif - static constexpr IndexType OutputSimdWidth = sizeof(outvec_t) / sizeof(OutputType); - - constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 8) / ChunkSize; - constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth; - std::uint16_t nnz[NumChunks]; - IndexType count; - - const auto input32 = reinterpret_cast(input); - - // Find indices of nonzero 32bit blocks - find_nnz(input32, nnz, count); - - const outvec_t* biasvec = reinterpret_cast(biases); - outvec_t acc[NumRegs]; - for (IndexType k = 0; k < NumRegs; ++k) - acc[k] = biasvec[k]; - - for (IndexType j = 0; j < count; ++j) - { - const auto i = nnz[j]; - const invec_t in = vec_set_32(input32[i]); - const auto col = reinterpret_cast(&weights[i * OutputDimensions * ChunkSize]); - for (IndexType k = 0; k < NumRegs; ++k) - vec_add_dpbusd_32(acc[k], in, col[k]); - } - - outvec_t* outptr = reinterpret_cast(output); - for (IndexType k = 0; k < NumRegs; ++k) - outptr[k] = acc[k]; -# undef vec_set_32 -# undef vec_add_dpbusd_32 + #if defined(USE_AVX512) + using invec_t = __m512i; + using outvec_t = __m512i; + #define vec_set_32 _mm512_set1_epi32 + #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32 + #elif defined(USE_AVX2) + using invec_t = __m256i; + using outvec_t = __m256i; + #define vec_set_32 _mm256_set1_epi32 + #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32 + #elif defined(USE_SSSE3) + using invec_t = __m128i; + using outvec_t = __m128i; + #define vec_set_32 _mm_set1_epi32 + #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32 + #elif defined(USE_NEON_DOTPROD) + using invec_t = int8x16_t; + using outvec_t = int32x4_t; + #define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a)) + #define vec_add_dpbusd_32 Simd::dotprod_m128_add_dpbusd_epi32 + #elif defined(USE_NEON) + using invec_t = int8x16_t; + using outvec_t = int32x4_t; + #define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a)) + #define vec_add_dpbusd_32 Simd::neon_m128_add_dpbusd_epi32 + #endif + static constexpr IndexType OutputSimdWidth = sizeof(outvec_t) / sizeof(OutputType); + + constexpr IndexType NumChunks = + ceil_to_multiple(InputDimensions, 8) / ChunkSize; + constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth; + std::uint16_t nnz[NumChunks]; + IndexType count; + + const auto input32 = reinterpret_cast(input); + + // Find indices of nonzero 32bit blocks + find_nnz(input32, nnz, count); + + const outvec_t* biasvec = reinterpret_cast(biases); + outvec_t acc[NumRegs]; + for (IndexType k = 0; k < NumRegs; ++k) acc[k] = biasvec[k]; + + for (IndexType j = 0; j < count; ++j) { + const auto i = nnz[j]; + const invec_t in = vec_set_32(input32[i]); + const auto col = + reinterpret_cast(&weights[i * OutputDimensions * ChunkSize]); + for (IndexType k = 0; k < NumRegs; ++k) vec_add_dpbusd_32(acc[k], in, col[k]); + } + + outvec_t* outptr = reinterpret_cast(output); + for (IndexType k = 0; k < NumRegs; ++k) outptr[k] = acc[k]; + #undef vec_set_32 + #undef vec_add_dpbusd_32 #else - // Use dense implementation for the other architectures. - affine_transform_non_ssse3< - InputDimensions, - PaddedInputDimensions, - OutputDimensions>(output, weights, biases, input); + // Use dense implementation for the other architectures. + affine_transform_non_ssse3( + output, weights, biases, input); #endif - } + } - private: - using BiasType = OutputType; - using WeightType = std::int8_t; + private: + using BiasType = OutputType; + using WeightType = std::int8_t; - alignas(CacheLineSize) BiasType biases[OutputDimensions]; - alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions]; - }; + alignas(CacheLineSize) BiasType biases[OutputDimensions]; + alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions]; + }; } // namespace Stockfish::Eval::NNUE::Layers -#endif // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_SPARSE_INPUT_H_INCLUDED +#endif // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_SPARSE_INPUT_H_INCLUDED diff --git a/src/nnue/layers/clipped_relu.h b/src/nnue/layers/clipped_relu.h index 2856bfb0a63..27041bdcac4 100644 --- a/src/nnue/layers/clipped_relu.h +++ b/src/nnue/layers/clipped_relu.h @@ -29,154 +29,151 @@ namespace Stockfish::Eval::NNUE::Layers { - // Clipped ReLU - template - class ClippedReLU { - public: - // Input/output type - using InputType = std::int32_t; - using OutputType = std::uint8_t; - - // Number of input/output dimensions - static constexpr IndexType InputDimensions = InDims; - static constexpr IndexType OutputDimensions = InputDimensions; - static constexpr IndexType PaddedOutputDimensions = - ceil_to_multiple(OutputDimensions, 32); - - using OutputBuffer = OutputType[PaddedOutputDimensions]; - - // Hash value embedded in the evaluation file - static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { - std::uint32_t hashValue = 0x538D24C7u; - hashValue += prevHash; - return hashValue; - } - - // Read network parameters - bool read_parameters(std::istream&) { - return true; - } - - // Write network parameters - bool write_parameters(std::ostream&) const { - return true; - } - - // Forward propagation - void propagate( - const InputType* input, OutputType* output) const { - - #if defined(USE_AVX2) - if constexpr (InputDimensions % SimdWidth == 0) { - constexpr IndexType NumChunks = InputDimensions / SimdWidth; - const __m256i Zero = _mm256_setzero_si256(); - const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); - const auto in = reinterpret_cast(input); - const auto out = reinterpret_cast<__m256i*>(output); - for (IndexType i = 0; i < NumChunks; ++i) { - const __m256i words0 = _mm256_srai_epi16(_mm256_packs_epi32( - _mm256_load_si256(&in[i * 4 + 0]), - _mm256_load_si256(&in[i * 4 + 1])), WeightScaleBits); - const __m256i words1 = _mm256_srai_epi16(_mm256_packs_epi32( - _mm256_load_si256(&in[i * 4 + 2]), - _mm256_load_si256(&in[i * 4 + 3])), WeightScaleBits); - _mm256_store_si256(&out[i], _mm256_permutevar8x32_epi32(_mm256_max_epi8( - _mm256_packs_epi16(words0, words1), Zero), Offsets)); + // Clipped ReLU + template class ClippedReLU { + public: + // Input/output type + using InputType = std::int32_t; + using OutputType = std::uint8_t; + + // Number of input/output dimensions + static constexpr IndexType InputDimensions = InDims; + static constexpr IndexType OutputDimensions = InputDimensions; + static constexpr IndexType PaddedOutputDimensions = + ceil_to_multiple(OutputDimensions, 32); + + using OutputBuffer = OutputType[PaddedOutputDimensions]; + + // Hash value embedded in the evaluation file + static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { + std::uint32_t hashValue = 0x538D24C7u; + hashValue += prevHash; + return hashValue; } - } else { - constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2); - const __m128i Zero = _mm_setzero_si128(); - const auto in = reinterpret_cast(input); - const auto out = reinterpret_cast<__m128i*>(output); - for (IndexType i = 0; i < NumChunks; ++i) { - const __m128i words0 = _mm_srai_epi16(_mm_packs_epi32( - _mm_load_si128(&in[i * 4 + 0]), - _mm_load_si128(&in[i * 4 + 1])), WeightScaleBits); - const __m128i words1 = _mm_srai_epi16(_mm_packs_epi32( - _mm_load_si128(&in[i * 4 + 2]), - _mm_load_si128(&in[i * 4 + 3])), WeightScaleBits); - const __m128i packedbytes = _mm_packs_epi16(words0, words1); - _mm_store_si128(&out[i], _mm_max_epi8(packedbytes, Zero)); + + // Read network parameters + bool read_parameters(std::istream&) { return true; } + + // Write network parameters + bool write_parameters(std::ostream&) const { return true; } + + // Forward propagation + void propagate(const InputType* input, OutputType* output) const { + +#if defined(USE_AVX2) + if constexpr (InputDimensions % SimdWidth == 0) { + constexpr IndexType NumChunks = InputDimensions / SimdWidth; + const __m256i Zero = _mm256_setzero_si256(); + const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + const auto in = reinterpret_cast(input); + const auto out = reinterpret_cast<__m256i*>(output); + for (IndexType i = 0; i < NumChunks; ++i) { + const __m256i words0 = + _mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 0]), + _mm256_load_si256(&in[i * 4 + 1])), + WeightScaleBits); + const __m256i words1 = + _mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 2]), + _mm256_load_si256(&in[i * 4 + 3])), + WeightScaleBits); + _mm256_store_si256( + &out[i], + _mm256_permutevar8x32_epi32( + _mm256_max_epi8(_mm256_packs_epi16(words0, words1), Zero), Offsets)); + } + } else { + constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2); + const __m128i Zero = _mm_setzero_si128(); + const auto in = reinterpret_cast(input); + const auto out = reinterpret_cast<__m128i*>(output); + for (IndexType i = 0; i < NumChunks; ++i) { + const __m128i words0 = + _mm_srai_epi16(_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), + _mm_load_si128(&in[i * 4 + 1])), + WeightScaleBits); + const __m128i words1 = + _mm_srai_epi16(_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), + _mm_load_si128(&in[i * 4 + 3])), + WeightScaleBits); + const __m128i packedbytes = _mm_packs_epi16(words0, words1); + _mm_store_si128(&out[i], _mm_max_epi8(packedbytes, Zero)); + } + } + constexpr IndexType Start = InputDimensions % SimdWidth == 0 ? + InputDimensions / SimdWidth * SimdWidth : + InputDimensions / (SimdWidth / 2) * (SimdWidth / 2); + +#elif defined(USE_SSE2) + constexpr IndexType NumChunks = InputDimensions / SimdWidth; + + #ifdef USE_SSE41 + const __m128i Zero = _mm_setzero_si128(); + #else + const __m128i k0x80s = _mm_set1_epi8(-128); + #endif + + const auto in = reinterpret_cast(input); + const auto out = reinterpret_cast<__m128i*>(output); + for (IndexType i = 0; i < NumChunks; ++i) { + const __m128i words0 = _mm_srai_epi16( + _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])), + WeightScaleBits); + const __m128i words1 = _mm_srai_epi16( + _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])), + WeightScaleBits); + const __m128i packedbytes = _mm_packs_epi16(words0, words1); + _mm_store_si128(&out[i], + + #ifdef USE_SSE41 + _mm_max_epi8(packedbytes, Zero) + #else + _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s) + #endif + + ); + } + constexpr IndexType Start = NumChunks * SimdWidth; + +#elif defined(USE_MMX) + constexpr IndexType NumChunks = InputDimensions / SimdWidth; + const __m64 k0x80s = _mm_set1_pi8(-128); + const auto in = reinterpret_cast(input); + const auto out = reinterpret_cast<__m64*>(output); + for (IndexType i = 0; i < NumChunks; ++i) { + const __m64 words0 = + _mm_srai_pi16(_mm_packs_pi32(in[i * 4 + 0], in[i * 4 + 1]), WeightScaleBits); + const __m64 words1 = + _mm_srai_pi16(_mm_packs_pi32(in[i * 4 + 2], in[i * 4 + 3]), WeightScaleBits); + const __m64 packedbytes = _mm_packs_pi16(words0, words1); + out[i] = _mm_subs_pi8(_mm_adds_pi8(packedbytes, k0x80s), k0x80s); + } + _mm_empty(); + constexpr IndexType Start = NumChunks * SimdWidth; + +#elif defined(USE_NEON) + constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2); + const int8x8_t Zero = {0}; + const auto in = reinterpret_cast(input); + const auto out = reinterpret_cast(output); + for (IndexType i = 0; i < NumChunks; ++i) { + int16x8_t shifted; + const auto pack = reinterpret_cast(&shifted); + pack[0] = vqshrn_n_s32(in[i * 2 + 0], WeightScaleBits); + pack[1] = vqshrn_n_s32(in[i * 2 + 1], WeightScaleBits); + out[i] = vmax_s8(vqmovn_s16(shifted), Zero); + } + constexpr IndexType Start = NumChunks * (SimdWidth / 2); +#else + constexpr IndexType Start = 0; +#endif + + for (IndexType i = Start; i < InputDimensions; ++i) { + output[i] = + static_cast(std::max(0, std::min(127, input[i] >> WeightScaleBits))); + } } - } - constexpr IndexType Start = - InputDimensions % SimdWidth == 0 - ? InputDimensions / SimdWidth * SimdWidth - : InputDimensions / (SimdWidth / 2) * (SimdWidth / 2); - - #elif defined(USE_SSE2) - constexpr IndexType NumChunks = InputDimensions / SimdWidth; - - #ifdef USE_SSE41 - const __m128i Zero = _mm_setzero_si128(); - #else - const __m128i k0x80s = _mm_set1_epi8(-128); - #endif - - const auto in = reinterpret_cast(input); - const auto out = reinterpret_cast<__m128i*>(output); - for (IndexType i = 0; i < NumChunks; ++i) { - const __m128i words0 = _mm_srai_epi16(_mm_packs_epi32( - _mm_load_si128(&in[i * 4 + 0]), - _mm_load_si128(&in[i * 4 + 1])), WeightScaleBits); - const __m128i words1 = _mm_srai_epi16(_mm_packs_epi32( - _mm_load_si128(&in[i * 4 + 2]), - _mm_load_si128(&in[i * 4 + 3])), WeightScaleBits); - const __m128i packedbytes = _mm_packs_epi16(words0, words1); - _mm_store_si128(&out[i], - - #ifdef USE_SSE41 - _mm_max_epi8(packedbytes, Zero) - #else - _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s) - #endif - - ); - } - constexpr IndexType Start = NumChunks * SimdWidth; - - #elif defined(USE_MMX) - constexpr IndexType NumChunks = InputDimensions / SimdWidth; - const __m64 k0x80s = _mm_set1_pi8(-128); - const auto in = reinterpret_cast(input); - const auto out = reinterpret_cast<__m64*>(output); - for (IndexType i = 0; i < NumChunks; ++i) { - const __m64 words0 = _mm_srai_pi16( - _mm_packs_pi32(in[i * 4 + 0], in[i * 4 + 1]), - WeightScaleBits); - const __m64 words1 = _mm_srai_pi16( - _mm_packs_pi32(in[i * 4 + 2], in[i * 4 + 3]), - WeightScaleBits); - const __m64 packedbytes = _mm_packs_pi16(words0, words1); - out[i] = _mm_subs_pi8(_mm_adds_pi8(packedbytes, k0x80s), k0x80s); - } - _mm_empty(); - constexpr IndexType Start = NumChunks * SimdWidth; - - #elif defined(USE_NEON) - constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2); - const int8x8_t Zero = {0}; - const auto in = reinterpret_cast(input); - const auto out = reinterpret_cast(output); - for (IndexType i = 0; i < NumChunks; ++i) { - int16x8_t shifted; - const auto pack = reinterpret_cast(&shifted); - pack[0] = vqshrn_n_s32(in[i * 2 + 0], WeightScaleBits); - pack[1] = vqshrn_n_s32(in[i * 2 + 1], WeightScaleBits); - out[i] = vmax_s8(vqmovn_s16(shifted), Zero); - } - constexpr IndexType Start = NumChunks * (SimdWidth / 2); - #else - constexpr IndexType Start = 0; - #endif - - for (IndexType i = Start; i < InputDimensions; ++i) { - output[i] = static_cast( - std::max(0, std::min(127, input[i] >> WeightScaleBits))); - } - } - }; + }; } // namespace Stockfish::Eval::NNUE::Layers -#endif // NNUE_LAYERS_CLIPPED_RELU_H_INCLUDED +#endif // NNUE_LAYERS_CLIPPED_RELU_H_INCLUDED diff --git a/src/nnue/layers/simd.h b/src/nnue/layers/simd.h index f478cd7819f..22f4bf42e92 100644 --- a/src/nnue/layers/simd.h +++ b/src/nnue/layers/simd.h @@ -20,30 +20,30 @@ #define STOCKFISH_SIMD_H_INCLUDED #if defined(USE_AVX2) -# include + #include #elif defined(USE_SSE41) -# include + #include #elif defined(USE_SSSE3) -# include + #include #elif defined(USE_SSE2) -# include + #include #elif defined(USE_MMX) -# include + #include #elif defined(USE_NEON) -# include + #include #endif namespace Stockfish::Simd { -#if defined (USE_AVX512) +#if defined(USE_AVX512) [[maybe_unused]] static int m512_hadd(__m512i sum, int bias) { - return _mm512_reduce_add_epi32(sum) + bias; + return _mm512_reduce_add_epi32(sum) + bias; } /* @@ -61,186 +61,168 @@ namespace Stockfish::Simd { reduce_add_epi32(zmm0.i128[3]), reduce_add_epi32(zmm1.i128[3]), reduce_add_epi32(zmm2.i128[3]), reduce_add_epi32(zmm3.i128[3]) ] */ - [[maybe_unused]] static __m512i m512_hadd128x16_interleave( - __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) { + [[maybe_unused]] static __m512i m512_hadd128x16_interleave(__m512i sum0, __m512i sum1, + __m512i sum2, __m512i sum3) { - __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1); - __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1); + __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1); + __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1); - __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3); - __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3); + __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3); + __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3); - __m512i sum01 = _mm512_add_epi32(sum01a, sum01b); - __m512i sum23 = _mm512_add_epi32(sum23a, sum23b); + __m512i sum01 = _mm512_add_epi32(sum01a, sum01b); + __m512i sum23 = _mm512_add_epi32(sum23a, sum23b); - __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23); - __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23); + __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23); + __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23); - return _mm512_add_epi32(sum0123a, sum0123b); + return _mm512_add_epi32(sum0123a, sum0123b); } - [[maybe_unused]] static void m512_add_dpbusd_epi32( - __m512i& acc, - __m512i a, - __m512i b) { + [[maybe_unused]] static void m512_add_dpbusd_epi32(__m512i& acc, __m512i a, __m512i b) { -# if defined (USE_VNNI) - acc = _mm512_dpbusd_epi32(acc, a, b); -# else - __m512i product0 = _mm512_maddubs_epi16(a, b); - product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1)); - acc = _mm512_add_epi32(acc, product0); -# endif + #if defined(USE_VNNI) + acc = _mm512_dpbusd_epi32(acc, a, b); + #else + __m512i product0 = _mm512_maddubs_epi16(a, b); + product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1)); + acc = _mm512_add_epi32(acc, product0); + #endif } - [[maybe_unused]] static void m512_add_dpbusd_epi32x2( - __m512i& acc, - __m512i a0, __m512i b0, - __m512i a1, __m512i b1) { + [[maybe_unused]] static void m512_add_dpbusd_epi32x2(__m512i& acc, __m512i a0, __m512i b0, + __m512i a1, __m512i b1) { -# if defined (USE_VNNI) - acc = _mm512_dpbusd_epi32(acc, a0, b0); - acc = _mm512_dpbusd_epi32(acc, a1, b1); -# else - __m512i product0 = _mm512_maddubs_epi16(a0, b0); - __m512i product1 = _mm512_maddubs_epi16(a1, b1); - product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1)); - product1 = _mm512_madd_epi16(product1, _mm512_set1_epi16(1)); - acc = _mm512_add_epi32(acc, _mm512_add_epi32(product0, product1)); -# endif + #if defined(USE_VNNI) + acc = _mm512_dpbusd_epi32(acc, a0, b0); + acc = _mm512_dpbusd_epi32(acc, a1, b1); + #else + __m512i product0 = _mm512_maddubs_epi16(a0, b0); + __m512i product1 = _mm512_maddubs_epi16(a1, b1); + product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1)); + product1 = _mm512_madd_epi16(product1, _mm512_set1_epi16(1)); + acc = _mm512_add_epi32(acc, _mm512_add_epi32(product0, product1)); + #endif } #endif -#if defined (USE_AVX2) +#if defined(USE_AVX2) [[maybe_unused]] static int m256_hadd(__m256i sum, int bias) { - __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1)); - sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC)); - sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB)); - return _mm_cvtsi128_si32(sum128) + bias; - } - - [[maybe_unused]] static void m256_add_dpbusd_epi32( - __m256i& acc, - __m256i a, - __m256i b) { - -# if defined (USE_VNNI) - acc = _mm256_dpbusd_epi32(acc, a, b); -# else - __m256i product0 = _mm256_maddubs_epi16(a, b); - product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1)); - acc = _mm256_add_epi32(acc, product0); -# endif - } - - [[maybe_unused]] static void m256_add_dpbusd_epi32x2( - __m256i& acc, - __m256i a0, __m256i b0, - __m256i a1, __m256i b1) { - -# if defined (USE_VNNI) - acc = _mm256_dpbusd_epi32(acc, a0, b0); - acc = _mm256_dpbusd_epi32(acc, a1, b1); -# else - __m256i product0 = _mm256_maddubs_epi16(a0, b0); - __m256i product1 = _mm256_maddubs_epi16(a1, b1); - product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1)); - product1 = _mm256_madd_epi16(product1, _mm256_set1_epi16(1)); - acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product1)); -# endif + __m128i sum128 = + _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1)); + sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC)); + sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB)); + return _mm_cvtsi128_si32(sum128) + bias; + } + + [[maybe_unused]] static void m256_add_dpbusd_epi32(__m256i& acc, __m256i a, __m256i b) { + + #if defined(USE_VNNI) + acc = _mm256_dpbusd_epi32(acc, a, b); + #else + __m256i product0 = _mm256_maddubs_epi16(a, b); + product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1)); + acc = _mm256_add_epi32(acc, product0); + #endif + } + + [[maybe_unused]] static void m256_add_dpbusd_epi32x2(__m256i& acc, __m256i a0, __m256i b0, + __m256i a1, __m256i b1) { + + #if defined(USE_VNNI) + acc = _mm256_dpbusd_epi32(acc, a0, b0); + acc = _mm256_dpbusd_epi32(acc, a1, b1); + #else + __m256i product0 = _mm256_maddubs_epi16(a0, b0); + __m256i product1 = _mm256_maddubs_epi16(a1, b1); + product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1)); + product1 = _mm256_madd_epi16(product1, _mm256_set1_epi16(1)); + acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product1)); + #endif } #endif -#if defined (USE_SSSE3) +#if defined(USE_SSSE3) [[maybe_unused]] static int m128_hadd(__m128i sum, int bias) { - sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC - sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB - return _mm_cvtsi128_si32(sum) + bias; + sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC + sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB + return _mm_cvtsi128_si32(sum) + bias; } - [[maybe_unused]] static void m128_add_dpbusd_epi32( - __m128i& acc, - __m128i a, - __m128i b) { + [[maybe_unused]] static void m128_add_dpbusd_epi32(__m128i& acc, __m128i a, __m128i b) { - __m128i product0 = _mm_maddubs_epi16(a, b); - product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1)); - acc = _mm_add_epi32(acc, product0); + __m128i product0 = _mm_maddubs_epi16(a, b); + product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1)); + acc = _mm_add_epi32(acc, product0); } - [[maybe_unused]] static void m128_add_dpbusd_epi32x2( - __m128i& acc, - __m128i a0, __m128i b0, - __m128i a1, __m128i b1) { + [[maybe_unused]] static void m128_add_dpbusd_epi32x2(__m128i& acc, __m128i a0, __m128i b0, + __m128i a1, __m128i b1) { - __m128i product0 = _mm_maddubs_epi16(a0, b0); - __m128i product1 = _mm_maddubs_epi16(a1, b1); - product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1)); - product1 = _mm_madd_epi16(product1, _mm_set1_epi16(1)); - acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product1)); + __m128i product0 = _mm_maddubs_epi16(a0, b0); + __m128i product1 = _mm_maddubs_epi16(a1, b1); + product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1)); + product1 = _mm_madd_epi16(product1, _mm_set1_epi16(1)); + acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product1)); } #endif -#if defined (USE_NEON_DOTPROD) +#if defined(USE_NEON_DOTPROD) - [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32x2( - int32x4_t& acc, - int8x16_t a0, int8x16_t b0, - int8x16_t a1, int8x16_t b1) { + [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32x2(int32x4_t& acc, int8x16_t a0, + int8x16_t b0, int8x16_t a1, + int8x16_t b1) { acc = vdotq_s32(acc, a0, b0); acc = vdotq_s32(acc, a1, b1); } - [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32( - int32x4_t& acc, - int8x16_t a, int8x16_t b) { + [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32(int32x4_t& acc, int8x16_t a, + int8x16_t b) { acc = vdotq_s32(acc, a, b); } #endif -#if defined (USE_NEON) +#if defined(USE_NEON) [[maybe_unused]] static int neon_m128_reduce_add_epi32(int32x4_t s) { -# if USE_NEON >= 8 - return vaddvq_s32(s); -# else - return s[0] + s[1] + s[2] + s[3]; -# endif + #if USE_NEON >= 8 + return vaddvq_s32(s); + #else + return s[0] + s[1] + s[2] + s[3]; + #endif } [[maybe_unused]] static int neon_m128_hadd(int32x4_t sum, int bias) { - return neon_m128_reduce_add_epi32(sum) + bias; + return neon_m128_reduce_add_epi32(sum) + bias; } - [[maybe_unused]] static void neon_m128_add_dpbusd_epi32x2( - int32x4_t& acc, - int8x8_t a0, int8x8_t b0, - int8x8_t a1, int8x8_t b1) { + [[maybe_unused]] static void neon_m128_add_dpbusd_epi32x2(int32x4_t& acc, int8x8_t a0, + int8x8_t b0, int8x8_t a1, + int8x8_t b1) { - int16x8_t product = vmull_s8(a0, b0); - product = vmlal_s8(product, a1, b1); - acc = vpadalq_s16(acc, product); + int16x8_t product = vmull_s8(a0, b0); + product = vmlal_s8(product, a1, b1); + acc = vpadalq_s16(acc, product); } #endif #if USE_NEON >= 8 - [[maybe_unused]] static void neon_m128_add_dpbusd_epi32( - int32x4_t& acc, - int8x16_t a, int8x16_t b) { + [[maybe_unused]] static void neon_m128_add_dpbusd_epi32(int32x4_t& acc, int8x16_t a, + int8x16_t b) { - int16x8_t product0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); - int16x8_t product1 = vmull_high_s8(a, b); - int16x8_t sum = vpaddq_s16(product0, product1); - acc = vpadalq_s16(acc, sum); + int16x8_t product0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); + int16x8_t product1 = vmull_high_s8(a, b); + int16x8_t sum = vpaddq_s16(product0, product1); + acc = vpadalq_s16(acc, sum); } #endif } -#endif // STOCKFISH_SIMD_H_INCLUDED +#endif // STOCKFISH_SIMD_H_INCLUDED diff --git a/src/nnue/layers/sqr_clipped_relu.h b/src/nnue/layers/sqr_clipped_relu.h index 503b283b25e..55d3a83526d 100644 --- a/src/nnue/layers/sqr_clipped_relu.h +++ b/src/nnue/layers/sqr_clipped_relu.h @@ -29,80 +29,73 @@ namespace Stockfish::Eval::NNUE::Layers { - // Clipped ReLU - template - class SqrClippedReLU { - public: - // Input/output type - using InputType = std::int32_t; - using OutputType = std::uint8_t; - - // Number of input/output dimensions - static constexpr IndexType InputDimensions = InDims; - static constexpr IndexType OutputDimensions = InputDimensions; - static constexpr IndexType PaddedOutputDimensions = - ceil_to_multiple(OutputDimensions, 32); - - using OutputBuffer = OutputType[PaddedOutputDimensions]; - - // Hash value embedded in the evaluation file - static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { - std::uint32_t hashValue = 0x538D24C7u; - hashValue += prevHash; - return hashValue; - } - - // Read network parameters - bool read_parameters(std::istream&) { - return true; - } - - // Write network parameters - bool write_parameters(std::ostream&) const { - return true; - } - - // Forward propagation - void propagate( - const InputType* input, OutputType* output) const { - - #if defined(USE_SSE2) - constexpr IndexType NumChunks = InputDimensions / 16; - - static_assert(WeightScaleBits == 6); - const auto in = reinterpret_cast(input); - const auto out = reinterpret_cast<__m128i*>(output); - for (IndexType i = 0; i < NumChunks; ++i) { - __m128i words0 = _mm_packs_epi32( - _mm_load_si128(&in[i * 4 + 0]), - _mm_load_si128(&in[i * 4 + 1])); - __m128i words1 = _mm_packs_epi32( - _mm_load_si128(&in[i * 4 + 2]), - _mm_load_si128(&in[i * 4 + 3])); - - // We shift by WeightScaleBits * 2 = 12 and divide by 128 - // which is an additional shift-right of 7, meaning 19 in total. - // MulHi strips the lower 16 bits so we need to shift out 3 more to match. - words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3); - words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3); - - _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1)); - } - constexpr IndexType Start = NumChunks * 16; - - #else - constexpr IndexType Start = 0; - #endif - - for (IndexType i = Start; i < InputDimensions; ++i) { - output[i] = static_cast( - // really should be /127 but we need to make it fast - // needs to be accounted for in the trainer - std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128)); - } - } - }; + // Clipped ReLU + template class SqrClippedReLU { + public: + // Input/output type + using InputType = std::int32_t; + using OutputType = std::uint8_t; + + // Number of input/output dimensions + static constexpr IndexType InputDimensions = InDims; + static constexpr IndexType OutputDimensions = InputDimensions; + static constexpr IndexType PaddedOutputDimensions = + ceil_to_multiple(OutputDimensions, 32); + + using OutputBuffer = OutputType[PaddedOutputDimensions]; + + // Hash value embedded in the evaluation file + static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { + std::uint32_t hashValue = 0x538D24C7u; + hashValue += prevHash; + return hashValue; + } + + // Read network parameters + bool read_parameters(std::istream&) { return true; } + + // Write network parameters + bool write_parameters(std::ostream&) const { return true; } + + // Forward propagation + void propagate(const InputType* input, OutputType* output) const { + +#if defined(USE_SSE2) + constexpr IndexType NumChunks = InputDimensions / 16; + + static_assert(WeightScaleBits == 6); + const auto in = reinterpret_cast(input); + const auto out = reinterpret_cast<__m128i*>(output); + for (IndexType i = 0; i < NumChunks; ++i) { + __m128i words0 = + _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])); + __m128i words1 = + _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])); + + // We shift by WeightScaleBits * 2 = 12 and divide by 128 + // which is an additional shift-right of 7, meaning 19 in total. + // MulHi strips the lower 16 bits so we need to shift out 3 more to match. + words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3); + words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3); + + _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1)); + } + constexpr IndexType Start = NumChunks * 16; + +#else + constexpr IndexType Start = 0; +#endif + + for (IndexType i = Start; i < InputDimensions; ++i) { + output[i] = static_cast( + // really should be /127 but we need to make it fast + // needs to be accounted for in the trainer + std::min(127ll, + (((long long) input[i] * input[i]) >> (2 * WeightScaleBits)) / 128)); + } + } + }; } // namespace Stockfish::Eval::NNUE::Layers -#endif // NNUE_LAYERS_SQR_CLIPPED_RELU_H_INCLUDED +#endif // NNUE_LAYERS_SQR_CLIPPED_RELU_H_INCLUDED diff --git a/src/nnue/nnue_accumulator.h b/src/nnue/nnue_accumulator.h index 03fc3bd5cd8..12d22fc99c9 100644 --- a/src/nnue/nnue_accumulator.h +++ b/src/nnue/nnue_accumulator.h @@ -28,13 +28,13 @@ namespace Stockfish::Eval::NNUE { - // Class that holds the result of affine transformation of input features - struct alignas(CacheLineSize) Accumulator { - std::int16_t accumulation[2][TransformedFeatureDimensions]; - std::int32_t psqtAccumulation[2][PSQTBuckets]; - bool computed[2]; - }; + // Class that holds the result of affine transformation of input features + struct alignas(CacheLineSize) Accumulator { + std::int16_t accumulation[2][TransformedFeatureDimensions]; + std::int32_t psqtAccumulation[2][PSQTBuckets]; + bool computed[2]; + }; } // namespace Stockfish::Eval::NNUE -#endif // NNUE_ACCUMULATOR_H_INCLUDED +#endif // NNUE_ACCUMULATOR_H_INCLUDED diff --git a/src/nnue/nnue_architecture.h b/src/nnue/nnue_architecture.h index b50c52df31f..3d95d4465d9 100644 --- a/src/nnue/nnue_architecture.h +++ b/src/nnue/nnue_architecture.h @@ -34,102 +34,95 @@ namespace Stockfish::Eval::NNUE { -// Input features used in evaluation function -using FeatureSet = Features::HalfKAv2_hm; - -// Number of input feature dimensions after conversion -constexpr IndexType TransformedFeatureDimensions = 2048; -constexpr IndexType PSQTBuckets = 8; -constexpr IndexType LayerStacks = 8; - -struct Network -{ - static constexpr int FC_0_OUTPUTS = 15; - static constexpr int FC_1_OUTPUTS = 32; - - Layers::AffineTransformSparseInput fc_0; - Layers::SqrClippedReLU ac_sqr_0; - Layers::ClippedReLU ac_0; - Layers::AffineTransform fc_1; - Layers::ClippedReLU ac_1; - Layers::AffineTransform fc_2; - - // Hash value embedded in the evaluation file - static constexpr std::uint32_t get_hash_value() { - // input slice hash - std::uint32_t hashValue = 0xEC42E90Du; - hashValue ^= TransformedFeatureDimensions * 2; - - hashValue = decltype(fc_0)::get_hash_value(hashValue); - hashValue = decltype(ac_0)::get_hash_value(hashValue); - hashValue = decltype(fc_1)::get_hash_value(hashValue); - hashValue = decltype(ac_1)::get_hash_value(hashValue); - hashValue = decltype(fc_2)::get_hash_value(hashValue); - - return hashValue; - } - - // Read network parameters - bool read_parameters(std::istream& stream) { - return fc_0.read_parameters(stream) - && ac_0.read_parameters(stream) - && fc_1.read_parameters(stream) - && ac_1.read_parameters(stream) - && fc_2.read_parameters(stream); - } - - // Write network parameters - bool write_parameters(std::ostream& stream) const { - return fc_0.write_parameters(stream) - && ac_0.write_parameters(stream) - && fc_1.write_parameters(stream) - && ac_1.write_parameters(stream) - && fc_2.write_parameters(stream); - } - - std::int32_t propagate(const TransformedFeatureType* transformedFeatures) - { - struct alignas(CacheLineSize) Buffer - { - alignas(CacheLineSize) decltype(fc_0)::OutputBuffer fc_0_out; - alignas(CacheLineSize) decltype(ac_sqr_0)::OutputType ac_sqr_0_out[ceil_to_multiple(FC_0_OUTPUTS * 2, 32)]; - alignas(CacheLineSize) decltype(ac_0)::OutputBuffer ac_0_out; - alignas(CacheLineSize) decltype(fc_1)::OutputBuffer fc_1_out; - alignas(CacheLineSize) decltype(ac_1)::OutputBuffer ac_1_out; - alignas(CacheLineSize) decltype(fc_2)::OutputBuffer fc_2_out; - - Buffer() - { - std::memset(this, 0, sizeof(*this)); - } - }; + // Input features used in evaluation function + using FeatureSet = Features::HalfKAv2_hm; + + // Number of input feature dimensions after conversion + constexpr IndexType TransformedFeatureDimensions = 2048; + constexpr IndexType PSQTBuckets = 8; + constexpr IndexType LayerStacks = 8; + + struct Network { + static constexpr int FC_0_OUTPUTS = 15; + static constexpr int FC_1_OUTPUTS = 32; + + Layers::AffineTransformSparseInput fc_0; + Layers::SqrClippedReLU ac_sqr_0; + Layers::ClippedReLU ac_0; + Layers::AffineTransform fc_1; + Layers::ClippedReLU ac_1; + Layers::AffineTransform fc_2; + + // Hash value embedded in the evaluation file + static constexpr std::uint32_t get_hash_value() { + // input slice hash + std::uint32_t hashValue = 0xEC42E90Du; + hashValue ^= TransformedFeatureDimensions * 2; + + hashValue = decltype(fc_0)::get_hash_value(hashValue); + hashValue = decltype(ac_0)::get_hash_value(hashValue); + hashValue = decltype(fc_1)::get_hash_value(hashValue); + hashValue = decltype(ac_1)::get_hash_value(hashValue); + hashValue = decltype(fc_2)::get_hash_value(hashValue); + + return hashValue; + } + + // Read network parameters + bool read_parameters(std::istream& stream) { + return fc_0.read_parameters(stream) && ac_0.read_parameters(stream) && + fc_1.read_parameters(stream) && ac_1.read_parameters(stream) && + fc_2.read_parameters(stream); + } + + // Write network parameters + bool write_parameters(std::ostream& stream) const { + return fc_0.write_parameters(stream) && ac_0.write_parameters(stream) && + fc_1.write_parameters(stream) && ac_1.write_parameters(stream) && + fc_2.write_parameters(stream); + } + + std::int32_t propagate(const TransformedFeatureType* transformedFeatures) { + struct alignas(CacheLineSize) Buffer { + alignas(CacheLineSize) decltype(fc_0)::OutputBuffer fc_0_out; + alignas(CacheLineSize) decltype(ac_sqr_0)::OutputType + ac_sqr_0_out[ceil_to_multiple(FC_0_OUTPUTS * 2, 32)]; + alignas(CacheLineSize) decltype(ac_0)::OutputBuffer ac_0_out; + alignas(CacheLineSize) decltype(fc_1)::OutputBuffer fc_1_out; + alignas(CacheLineSize) decltype(ac_1)::OutputBuffer ac_1_out; + alignas(CacheLineSize) decltype(fc_2)::OutputBuffer fc_2_out; + + Buffer() { std::memset(this, 0, sizeof(*this)); } + }; #if defined(__clang__) && (__APPLE__) - // workaround for a bug reported with xcode 12 - static thread_local auto tlsBuffer = std::make_unique(); - // Access TLS only once, cache result. - Buffer& buffer = *tlsBuffer; + // workaround for a bug reported with xcode 12 + static thread_local auto tlsBuffer = std::make_unique(); + // Access TLS only once, cache result. + Buffer& buffer = *tlsBuffer; #else - alignas(CacheLineSize) static thread_local Buffer buffer; + alignas(CacheLineSize) static thread_local Buffer buffer; #endif - fc_0.propagate(transformedFeatures, buffer.fc_0_out); - ac_sqr_0.propagate(buffer.fc_0_out, buffer.ac_sqr_0_out); - ac_0.propagate(buffer.fc_0_out, buffer.ac_0_out); - std::memcpy(buffer.ac_sqr_0_out + FC_0_OUTPUTS, buffer.ac_0_out, FC_0_OUTPUTS * sizeof(decltype(ac_0)::OutputType)); - fc_1.propagate(buffer.ac_sqr_0_out, buffer.fc_1_out); - ac_1.propagate(buffer.fc_1_out, buffer.ac_1_out); - fc_2.propagate(buffer.ac_1_out, buffer.fc_2_out); - - // buffer.fc_0_out[FC_0_OUTPUTS] is such that 1.0 is equal to 127*(1< + #include #elif defined(USE_SSE41) -#include + #include #elif defined(USE_SSSE3) -#include + #include #elif defined(USE_SSE2) -#include + #include #elif defined(USE_MMX) -#include + #include #elif defined(USE_NEON) -#include + #include #endif namespace Stockfish::Eval::NNUE { - // Version of the evaluation file - constexpr std::uint32_t Version = 0x7AF32F20u; - - // Constant used in evaluation value calculation - constexpr int OutputScale = 16; - constexpr int WeightScaleBits = 6; - - // Size of cache line (in bytes) - constexpr std::size_t CacheLineSize = 64; - - constexpr const char Leb128MagicString[] = "COMPRESSED_LEB128"; - constexpr const std::size_t Leb128MagicStringSize = sizeof(Leb128MagicString) - 1; - - // SIMD width (in bytes) - #if defined(USE_AVX2) - constexpr std::size_t SimdWidth = 32; - - #elif defined(USE_SSE2) - constexpr std::size_t SimdWidth = 16; - - #elif defined(USE_MMX) - constexpr std::size_t SimdWidth = 8; - - #elif defined(USE_NEON) - constexpr std::size_t SimdWidth = 16; - #endif - - constexpr std::size_t MaxSimdWidth = 32; - - // Type of input feature after conversion - using TransformedFeatureType = std::uint8_t; - using IndexType = std::uint32_t; - - // Round n up to be a multiple of base - template - constexpr IntType ceil_to_multiple(IntType n, IntType base) { - return (n + base - 1) / base * base; - } - - - // read_little_endian() is our utility to read an integer (signed or unsigned, any size) - // from a stream in little-endian order. We swap the byte order after the read if - // necessary to return a result with the byte ordering of the compiling machine. - template - inline IntType read_little_endian(std::istream& stream) { - IntType result; - - if (IsLittleEndian) - stream.read(reinterpret_cast(&result), sizeof(IntType)); - else - { - std::uint8_t u[sizeof(IntType)]; - typename std::make_unsigned::type v = 0; - - stream.read(reinterpret_cast(u), sizeof(IntType)); - for (std::size_t i = 0; i < sizeof(IntType); ++i) - v = (v << 8) | u[sizeof(IntType) - i - 1]; - - std::memcpy(&result, &v, sizeof(IntType)); - } - - return result; - } - - - // write_little_endian() is our utility to write an integer (signed or unsigned, any size) - // to a stream in little-endian order. We swap the byte order before the write if - // necessary to always write in little endian order, independently of the byte - // ordering of the compiling machine. - template - inline void write_little_endian(std::ostream& stream, IntType value) { - - if (IsLittleEndian) - stream.write(reinterpret_cast(&value), sizeof(IntType)); - else - { - std::uint8_t u[sizeof(IntType)]; - typename std::make_unsigned::type v = value; - - std::size_t i = 0; - // if constexpr to silence the warning about shift by 8 - if constexpr (sizeof(IntType) > 1) - { - for (; i + 1 < sizeof(IntType); ++i) - { - u[i] = (std::uint8_t)v; - v >>= 8; + // Version of the evaluation file + constexpr std::uint32_t Version = 0x7AF32F20u; + + // Constant used in evaluation value calculation + constexpr int OutputScale = 16; + constexpr int WeightScaleBits = 6; + + // Size of cache line (in bytes) + constexpr std::size_t CacheLineSize = 64; + + constexpr const char Leb128MagicString[] = "COMPRESSED_LEB128"; + constexpr const std::size_t Leb128MagicStringSize = sizeof(Leb128MagicString) - 1; + +// SIMD width (in bytes) +#if defined(USE_AVX2) + constexpr std::size_t SimdWidth = 32; + +#elif defined(USE_SSE2) + constexpr std::size_t SimdWidth = 16; + +#elif defined(USE_MMX) + constexpr std::size_t SimdWidth = 8; + +#elif defined(USE_NEON) + constexpr std::size_t SimdWidth = 16; +#endif + + constexpr std::size_t MaxSimdWidth = 32; + + // Type of input feature after conversion + using TransformedFeatureType = std::uint8_t; + using IndexType = std::uint32_t; + + // Round n up to be a multiple of base + template constexpr IntType ceil_to_multiple(IntType n, IntType base) { + return (n + base - 1) / base * base; + } + + + // read_little_endian() is our utility to read an integer (signed or unsigned, any size) + // from a stream in little-endian order. We swap the byte order after the read if + // necessary to return a result with the byte ordering of the compiling machine. + template inline IntType read_little_endian(std::istream& stream) { + IntType result; + + if (IsLittleEndian) + stream.read(reinterpret_cast(&result), sizeof(IntType)); + else { + std::uint8_t u[sizeof(IntType)]; + typename std::make_unsigned::type v = 0; + + stream.read(reinterpret_cast(u), sizeof(IntType)); + for (std::size_t i = 0; i < sizeof(IntType); ++i) + v = (v << 8) | u[sizeof(IntType) - i - 1]; + + std::memcpy(&result, &v, sizeof(IntType)); + } + + return result; + } + + + // write_little_endian() is our utility to write an integer (signed or unsigned, any size) + // to a stream in little-endian order. We swap the byte order before the write if + // necessary to always write in little endian order, independently of the byte + // ordering of the compiling machine. + template + inline void write_little_endian(std::ostream& stream, IntType value) { + + if (IsLittleEndian) + stream.write(reinterpret_cast(&value), sizeof(IntType)); + else { + std::uint8_t u[sizeof(IntType)]; + typename std::make_unsigned::type v = value; + + std::size_t i = 0; + // if constexpr to silence the warning about shift by 8 + if constexpr (sizeof(IntType) > 1) { + for (; i + 1 < sizeof(IntType); ++i) { + u[i] = (std::uint8_t) v; + v >>= 8; + } + } + u[i] = (std::uint8_t) v; + + stream.write(reinterpret_cast(u), sizeof(IntType)); + } + } + + + // read_little_endian(s, out, N) : read integers in bulk from a little indian stream. + // This reads N integers from stream s and put them in array out. + template + inline void read_little_endian(std::istream& stream, IntType* out, std::size_t count) { + if (IsLittleEndian) + stream.read(reinterpret_cast(out), sizeof(IntType) * count); + else + for (std::size_t i = 0; i < count; ++i) out[i] = read_little_endian(stream); + } + + + // write_little_endian(s, values, N) : write integers in bulk to a little indian stream. + // This takes N integers from array values and writes them on stream s. + template inline void + write_little_endian(std::ostream& stream, const IntType* values, std::size_t count) { + if (IsLittleEndian) + stream.write(reinterpret_cast(values), sizeof(IntType) * count); + else + for (std::size_t i = 0; i < count; ++i) write_little_endian(stream, values[i]); + } + + + // read_leb_128(s, out, N) : read N signed integers from the stream s, putting them in + // the array out. The stream is assumed to be compressed using the signed LEB128 format. + // See https://en.wikipedia.org/wiki/LEB128 for a description of the compression scheme. + template + inline void read_leb_128(std::istream& stream, IntType* out, std::size_t count) { + + // Check the presence of our LEB128 magic string + char leb128MagicString[Leb128MagicStringSize]; + stream.read(leb128MagicString, Leb128MagicStringSize); + assert(strncmp(Leb128MagicString, leb128MagicString, Leb128MagicStringSize) == 0); + + static_assert(std::is_signed_v, "Not implemented for unsigned types"); + + const std::uint32_t BUF_SIZE = 4096; + std::uint8_t buf[BUF_SIZE]; + + auto bytes_left = read_little_endian(stream); + + std::uint32_t buf_pos = BUF_SIZE; + for (std::size_t i = 0; i < count; ++i) { + IntType result = 0; + size_t shift = 0; + do { + if (buf_pos == BUF_SIZE) { + stream.read(reinterpret_cast(buf), std::min(bytes_left, BUF_SIZE)); + buf_pos = 0; + } + + std::uint8_t byte = buf[buf_pos++]; + --bytes_left; + result |= (byte & 0x7f) << shift; + shift += 7; + + if ((byte & 0x80) == 0) { + out[i] = (sizeof(IntType) * 8 <= shift || (byte & 0x40) == 0) ? + result : + result | ~((1 << shift) - 1); + break; + } + } while (shift < sizeof(IntType) * 8); + } + + assert(bytes_left == 0); + } + + + // write_leb_128(s, values, N) : write signed integers to a stream with LEB128 compression. + // This takes N integers from array values, compress them with the LEB128 algorithm and + // writes the result on the stream s. + // See https://en.wikipedia.org/wiki/LEB128 for a description of the compression scheme. + template + inline void write_leb_128(std::ostream& stream, const IntType* values, std::size_t count) { + + // Write our LEB128 magic string + stream.write(Leb128MagicString, Leb128MagicStringSize); + + static_assert(std::is_signed_v, "Not implemented for unsigned types"); + + std::uint32_t byte_count = 0; + for (std::size_t i = 0; i < count; ++i) { + IntType value = values[i]; + std::uint8_t byte; + do { + byte = value & 0x7f; + value >>= 7; + ++byte_count; + } while ((byte & 0x40) == 0 ? value != 0 : value != -1); + } + + write_little_endian(stream, byte_count); + + const std::uint32_t BUF_SIZE = 4096; + std::uint8_t buf[BUF_SIZE]; + std::uint32_t buf_pos = 0; + + auto flush = [&]() { + if (buf_pos > 0) { + stream.write(reinterpret_cast(buf), buf_pos); + buf_pos = 0; } - } - u[i] = (std::uint8_t)v; - - stream.write(reinterpret_cast(u), sizeof(IntType)); - } - } - - - // read_little_endian(s, out, N) : read integers in bulk from a little indian stream. - // This reads N integers from stream s and put them in array out. - template - inline void read_little_endian(std::istream& stream, IntType* out, std::size_t count) { - if (IsLittleEndian) - stream.read(reinterpret_cast(out), sizeof(IntType) * count); - else - for (std::size_t i = 0; i < count; ++i) - out[i] = read_little_endian(stream); - } - - - // write_little_endian(s, values, N) : write integers in bulk to a little indian stream. - // This takes N integers from array values and writes them on stream s. - template - inline void write_little_endian(std::ostream& stream, const IntType* values, std::size_t count) { - if (IsLittleEndian) - stream.write(reinterpret_cast(values), sizeof(IntType) * count); - else - for (std::size_t i = 0; i < count; ++i) - write_little_endian(stream, values[i]); - } - - - // read_leb_128(s, out, N) : read N signed integers from the stream s, putting them in - // the array out. The stream is assumed to be compressed using the signed LEB128 format. - // See https://en.wikipedia.org/wiki/LEB128 for a description of the compression scheme. - template - inline void read_leb_128(std::istream& stream, IntType* out, std::size_t count) { - - // Check the presence of our LEB128 magic string - char leb128MagicString[Leb128MagicStringSize]; - stream.read(leb128MagicString, Leb128MagicStringSize); - assert(strncmp(Leb128MagicString, leb128MagicString, Leb128MagicStringSize) == 0); - - static_assert(std::is_signed_v, "Not implemented for unsigned types"); - - const std::uint32_t BUF_SIZE = 4096; - std::uint8_t buf[BUF_SIZE]; - - auto bytes_left = read_little_endian(stream); - - std::uint32_t buf_pos = BUF_SIZE; - for (std::size_t i = 0; i < count; ++i) - { - IntType result = 0; - size_t shift = 0; - do - { - if (buf_pos == BUF_SIZE) - { - stream.read(reinterpret_cast(buf), std::min(bytes_left, BUF_SIZE)); - buf_pos = 0; - } - - std::uint8_t byte = buf[buf_pos++]; - --bytes_left; - result |= (byte & 0x7f) << shift; - shift += 7; - - if ((byte & 0x80) == 0) - { - out[i] = (sizeof(IntType) * 8 <= shift || (byte & 0x40) == 0) ? result - : result | ~((1 << shift) - 1); - break; - } - } - while (shift < sizeof(IntType) * 8); - } - - assert(bytes_left == 0); - } - - - // write_leb_128(s, values, N) : write signed integers to a stream with LEB128 compression. - // This takes N integers from array values, compress them with the LEB128 algorithm and - // writes the result on the stream s. - // See https://en.wikipedia.org/wiki/LEB128 for a description of the compression scheme. - template - inline void write_leb_128(std::ostream& stream, const IntType* values, std::size_t count) { - - // Write our LEB128 magic string - stream.write(Leb128MagicString, Leb128MagicStringSize); - - static_assert(std::is_signed_v, "Not implemented for unsigned types"); - - std::uint32_t byte_count = 0; - for (std::size_t i = 0; i < count; ++i) - { - IntType value = values[i]; - std::uint8_t byte; - do - { - byte = value & 0x7f; - value >>= 7; - ++byte_count; - } - while ((byte & 0x40) == 0 ? value != 0 : value != -1); - } - - write_little_endian(stream, byte_count); - - const std::uint32_t BUF_SIZE = 4096; - std::uint8_t buf[BUF_SIZE]; - std::uint32_t buf_pos = 0; - - auto flush = [&]() { - if (buf_pos > 0) - { - stream.write(reinterpret_cast(buf), buf_pos); - buf_pos = 0; - } - }; - - auto write = [&](std::uint8_t byte) { - buf[buf_pos++] = byte; - if (buf_pos == BUF_SIZE) - flush(); - }; - - for (std::size_t i = 0; i < count; ++i) - { - IntType value = values[i]; - while (true) - { - std::uint8_t byte = value & 0x7f; - value >>= 7; - if ((byte & 0x40) == 0 ? value == 0 : value == -1) - { - write(byte); - break; - } - write(byte | 0x80); - } - } - - flush(); - } + }; + + auto write = [&](std::uint8_t byte) { + buf[buf_pos++] = byte; + if (buf_pos == BUF_SIZE) flush(); + }; + + for (std::size_t i = 0; i < count; ++i) { + IntType value = values[i]; + while (true) { + std::uint8_t byte = value & 0x7f; + value >>= 7; + if ((byte & 0x40) == 0 ? value == 0 : value == -1) { + write(byte); + break; + } + write(byte | 0x80); + } + } + + flush(); + } } // namespace Stockfish::Eval::NNUE -#endif // #ifndef NNUE_COMMON_H_INCLUDED +#endif // #ifndef NNUE_COMMON_H_INCLUDED diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index 0af0ed96cc5..54be71293e4 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -36,646 +36,614 @@ namespace Stockfish::Eval::NNUE { - using BiasType = std::int16_t; - using WeightType = std::int16_t; - using PSQTWeightType = std::int32_t; - - // If vector instructions are enabled, we update and refresh the - // accumulator tile by tile such that each tile fits in the CPU's - // vector registers. - #define VECTOR - - static_assert(PSQTBuckets % 8 == 0, - "Per feature PSQT values cannot be processed at granularity lower than 8 at a time."); - - #ifdef USE_AVX512 - using vec_t = __m512i; - using psqt_vec_t = __m256i; - #define vec_load(a) _mm512_load_si512(a) - #define vec_store(a,b) _mm512_store_si512(a,b) - #define vec_add_16(a,b) _mm512_add_epi16(a,b) - #define vec_sub_16(a,b) _mm512_sub_epi16(a,b) - #define vec_mul_16(a,b) _mm512_mullo_epi16(a,b) - #define vec_zero() _mm512_setzero_epi32() - #define vec_set_16(a) _mm512_set1_epi16(a) - #define vec_max_16(a,b) _mm512_max_epi16(a,b) - #define vec_min_16(a,b) _mm512_min_epi16(a,b) - inline vec_t vec_msb_pack_16(vec_t a, vec_t b){ - vec_t compacted = _mm512_packs_epi16(_mm512_srli_epi16(a,7),_mm512_srli_epi16(b,7)); - return _mm512_permutexvar_epi64(_mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7), compacted); - } - #define vec_load_psqt(a) _mm256_load_si256(a) - #define vec_store_psqt(a,b) _mm256_store_si256(a,b) - #define vec_add_psqt_32(a,b) _mm256_add_epi32(a,b) - #define vec_sub_psqt_32(a,b) _mm256_sub_epi32(a,b) - #define vec_zero_psqt() _mm256_setzero_si256() - #define NumRegistersSIMD 32 - #define MaxChunkSize 64 - - #elif USE_AVX2 - using vec_t = __m256i; - using psqt_vec_t = __m256i; - #define vec_load(a) _mm256_load_si256(a) - #define vec_store(a,b) _mm256_store_si256(a,b) - #define vec_add_16(a,b) _mm256_add_epi16(a,b) - #define vec_sub_16(a,b) _mm256_sub_epi16(a,b) - #define vec_mul_16(a,b) _mm256_mullo_epi16(a,b) - #define vec_zero() _mm256_setzero_si256() - #define vec_set_16(a) _mm256_set1_epi16(a) - #define vec_max_16(a,b) _mm256_max_epi16(a,b) - #define vec_min_16(a,b) _mm256_min_epi16(a,b) - inline vec_t vec_msb_pack_16(vec_t a, vec_t b){ - vec_t compacted = _mm256_packs_epi16(_mm256_srli_epi16(a,7), _mm256_srli_epi16(b,7)); - return _mm256_permute4x64_epi64(compacted, 0b11011000); - } - #define vec_load_psqt(a) _mm256_load_si256(a) - #define vec_store_psqt(a,b) _mm256_store_si256(a,b) - #define vec_add_psqt_32(a,b) _mm256_add_epi32(a,b) - #define vec_sub_psqt_32(a,b) _mm256_sub_epi32(a,b) - #define vec_zero_psqt() _mm256_setzero_si256() - #define NumRegistersSIMD 16 - #define MaxChunkSize 32 - - #elif USE_SSE2 - using vec_t = __m128i; - using psqt_vec_t = __m128i; - #define vec_load(a) (*(a)) - #define vec_store(a,b) *(a)=(b) - #define vec_add_16(a,b) _mm_add_epi16(a,b) - #define vec_sub_16(a,b) _mm_sub_epi16(a,b) - #define vec_mul_16(a,b) _mm_mullo_epi16(a,b) - #define vec_zero() _mm_setzero_si128() - #define vec_set_16(a) _mm_set1_epi16(a) - #define vec_max_16(a,b) _mm_max_epi16(a,b) - #define vec_min_16(a,b) _mm_min_epi16(a,b) - #define vec_msb_pack_16(a,b) _mm_packs_epi16(_mm_srli_epi16(a,7),_mm_srli_epi16(b,7)) - #define vec_load_psqt(a) (*(a)) - #define vec_store_psqt(a,b) *(a)=(b) - #define vec_add_psqt_32(a,b) _mm_add_epi32(a,b) - #define vec_sub_psqt_32(a,b) _mm_sub_epi32(a,b) - #define vec_zero_psqt() _mm_setzero_si128() - #define NumRegistersSIMD (Is64Bit ? 16 : 8) - #define MaxChunkSize 16 - - #elif USE_MMX - using vec_t = __m64; - using psqt_vec_t = __m64; - #define vec_load(a) (*(a)) - #define vec_store(a,b) *(a)=(b) - #define vec_add_16(a,b) _mm_add_pi16(a,b) - #define vec_sub_16(a,b) _mm_sub_pi16(a,b) - #define vec_mul_16(a,b) _mm_mullo_pi16(a,b) - #define vec_zero() _mm_setzero_si64() - #define vec_set_16(a) _mm_set1_pi16(a) - inline vec_t vec_max_16(vec_t a,vec_t b){ - vec_t comparison = _mm_cmpgt_pi16(a,b); - return _mm_or_si64(_mm_and_si64(comparison, a), _mm_andnot_si64(comparison, b)); - } - inline vec_t vec_min_16(vec_t a,vec_t b){ - vec_t comparison = _mm_cmpgt_pi16(a,b); - return _mm_or_si64(_mm_and_si64(comparison, b), _mm_andnot_si64(comparison, a)); - } - #define vec_msb_pack_16(a,b) _mm_packs_pi16(_mm_srli_pi16(a,7),_mm_srli_pi16(b,7)) - #define vec_load_psqt(a) (*(a)) - #define vec_store_psqt(a,b) *(a)=(b) - #define vec_add_psqt_32(a,b) _mm_add_pi32(a,b) - #define vec_sub_psqt_32(a,b) _mm_sub_pi32(a,b) - #define vec_zero_psqt() _mm_setzero_si64() - #define vec_cleanup() _mm_empty() - #define NumRegistersSIMD 8 - #define MaxChunkSize 8 - - #elif USE_NEON - using vec_t = int16x8_t; - using psqt_vec_t = int32x4_t; - #define vec_load(a) (*(a)) - #define vec_store(a,b) *(a)=(b) - #define vec_add_16(a,b) vaddq_s16(a,b) - #define vec_sub_16(a,b) vsubq_s16(a,b) - #define vec_mul_16(a,b) vmulq_s16(a,b) - #define vec_zero() vec_t{0} - #define vec_set_16(a) vdupq_n_s16(a) - #define vec_max_16(a,b) vmaxq_s16(a,b) - #define vec_min_16(a,b) vminq_s16(a,b) - inline vec_t vec_msb_pack_16(vec_t a, vec_t b){ - const int8x8_t shifta = vshrn_n_s16(a, 7); - const int8x8_t shiftb = vshrn_n_s16(b, 7); - const int8x16_t compacted = vcombine_s8(shifta,shiftb); - return *reinterpret_cast (&compacted); - } - #define vec_load_psqt(a) (*(a)) - #define vec_store_psqt(a,b) *(a)=(b) - #define vec_add_psqt_32(a,b) vaddq_s32(a,b) - #define vec_sub_psqt_32(a,b) vsubq_s32(a,b) - #define vec_zero_psqt() psqt_vec_t{0} - #define NumRegistersSIMD 16 - #define MaxChunkSize 16 - - #else - #undef VECTOR - - #endif - - - #ifdef VECTOR - - // Compute optimal SIMD register count for feature transformer accumulation. - - // We use __m* types as template arguments, which causes GCC to emit warnings - // about losing some attribute information. This is irrelevant to us as we - // only take their size, so the following pragma are harmless. - #if defined(__GNUC__) - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wignored-attributes" - #endif - - template - static constexpr int BestRegisterCount() - { - #define RegisterSize sizeof(SIMDRegisterType) - #define LaneSize sizeof(LaneType) - - static_assert(RegisterSize >= LaneSize); - static_assert(MaxRegisters <= NumRegistersSIMD); - static_assert(MaxRegisters > 0); - static_assert(NumRegistersSIMD > 0); - static_assert(RegisterSize % LaneSize == 0); - static_assert((NumLanes * LaneSize) % RegisterSize == 0); - - const int ideal = (NumLanes * LaneSize) / RegisterSize; - if (ideal <= MaxRegisters) - return ideal; - - // Look for the largest divisor of the ideal register count that is smaller than MaxRegisters - for (int divisor = MaxRegisters; divisor > 1; --divisor) - if (ideal % divisor == 0) - return divisor; - - return 1; - } - - static constexpr int NumRegs = BestRegisterCount(); - static constexpr int NumPsqtRegs = BestRegisterCount(); - #if defined(__GNUC__) - #pragma GCC diagnostic pop - #endif - #endif - - - - // Input feature converter - class FeatureTransformer { - - private: - // Number of output dimensions for one side - static constexpr IndexType HalfDimensions = TransformedFeatureDimensions; - - #ifdef VECTOR - static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2; - static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4; - static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions"); - static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets"); + using BiasType = std::int16_t; + using WeightType = std::int16_t; + using PSQTWeightType = std::int32_t; + +// If vector instructions are enabled, we update and refresh the +// accumulator tile by tile such that each tile fits in the CPU's +// vector registers. +#define VECTOR + + static_assert( + PSQTBuckets % 8 == 0, + "Per feature PSQT values cannot be processed at granularity lower than 8 at a time."); + +#ifdef USE_AVX512 + using vec_t = __m512i; + using psqt_vec_t = __m256i; + #define vec_load(a) _mm512_load_si512(a) + #define vec_store(a, b) _mm512_store_si512(a, b) + #define vec_add_16(a, b) _mm512_add_epi16(a, b) + #define vec_sub_16(a, b) _mm512_sub_epi16(a, b) + #define vec_mul_16(a, b) _mm512_mullo_epi16(a, b) + #define vec_zero() _mm512_setzero_epi32() + #define vec_set_16(a) _mm512_set1_epi16(a) + #define vec_max_16(a, b) _mm512_max_epi16(a, b) + #define vec_min_16(a, b) _mm512_min_epi16(a, b) + inline vec_t vec_msb_pack_16(vec_t a, vec_t b) { + vec_t compacted = _mm512_packs_epi16(_mm512_srli_epi16(a, 7), _mm512_srli_epi16(b, 7)); + return _mm512_permutexvar_epi64(_mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7), compacted); + } + #define vec_load_psqt(a) _mm256_load_si256(a) + #define vec_store_psqt(a, b) _mm256_store_si256(a, b) + #define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b) + #define vec_sub_psqt_32(a, b) _mm256_sub_epi32(a, b) + #define vec_zero_psqt() _mm256_setzero_si256() + #define NumRegistersSIMD 32 + #define MaxChunkSize 64 + +#elif USE_AVX2 + using vec_t = __m256i; + using psqt_vec_t = __m256i; + #define vec_load(a) _mm256_load_si256(a) + #define vec_store(a, b) _mm256_store_si256(a, b) + #define vec_add_16(a, b) _mm256_add_epi16(a, b) + #define vec_sub_16(a, b) _mm256_sub_epi16(a, b) + #define vec_mul_16(a, b) _mm256_mullo_epi16(a, b) + #define vec_zero() _mm256_setzero_si256() + #define vec_set_16(a) _mm256_set1_epi16(a) + #define vec_max_16(a, b) _mm256_max_epi16(a, b) + #define vec_min_16(a, b) _mm256_min_epi16(a, b) + inline vec_t vec_msb_pack_16(vec_t a, vec_t b) { + vec_t compacted = _mm256_packs_epi16(_mm256_srli_epi16(a, 7), _mm256_srli_epi16(b, 7)); + return _mm256_permute4x64_epi64(compacted, 0b11011000); + } + #define vec_load_psqt(a) _mm256_load_si256(a) + #define vec_store_psqt(a, b) _mm256_store_si256(a, b) + #define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b) + #define vec_sub_psqt_32(a, b) _mm256_sub_epi32(a, b) + #define vec_zero_psqt() _mm256_setzero_si256() + #define NumRegistersSIMD 16 + #define MaxChunkSize 32 + +#elif USE_SSE2 + using vec_t = __m128i; + using psqt_vec_t = __m128i; + #define vec_load(a) (*(a)) + #define vec_store(a, b) *(a) = (b) + #define vec_add_16(a, b) _mm_add_epi16(a, b) + #define vec_sub_16(a, b) _mm_sub_epi16(a, b) + #define vec_mul_16(a, b) _mm_mullo_epi16(a, b) + #define vec_zero() _mm_setzero_si128() + #define vec_set_16(a) _mm_set1_epi16(a) + #define vec_max_16(a, b) _mm_max_epi16(a, b) + #define vec_min_16(a, b) _mm_min_epi16(a, b) + #define vec_msb_pack_16(a, b) _mm_packs_epi16(_mm_srli_epi16(a, 7), _mm_srli_epi16(b, 7)) + #define vec_load_psqt(a) (*(a)) + #define vec_store_psqt(a, b) *(a) = (b) + #define vec_add_psqt_32(a, b) _mm_add_epi32(a, b) + #define vec_sub_psqt_32(a, b) _mm_sub_epi32(a, b) + #define vec_zero_psqt() _mm_setzero_si128() + #define NumRegistersSIMD (Is64Bit ? 16 : 8) + #define MaxChunkSize 16 + +#elif USE_MMX + using vec_t = __m64; + using psqt_vec_t = __m64; + #define vec_load(a) (*(a)) + #define vec_store(a, b) *(a) = (b) + #define vec_add_16(a, b) _mm_add_pi16(a, b) + #define vec_sub_16(a, b) _mm_sub_pi16(a, b) + #define vec_mul_16(a, b) _mm_mullo_pi16(a, b) + #define vec_zero() _mm_setzero_si64() + #define vec_set_16(a) _mm_set1_pi16(a) + inline vec_t vec_max_16(vec_t a, vec_t b) { + vec_t comparison = _mm_cmpgt_pi16(a, b); + return _mm_or_si64(_mm_and_si64(comparison, a), _mm_andnot_si64(comparison, b)); + } + inline vec_t vec_min_16(vec_t a, vec_t b) { + vec_t comparison = _mm_cmpgt_pi16(a, b); + return _mm_or_si64(_mm_and_si64(comparison, b), _mm_andnot_si64(comparison, a)); + } + #define vec_msb_pack_16(a, b) _mm_packs_pi16(_mm_srli_pi16(a, 7), _mm_srli_pi16(b, 7)) + #define vec_load_psqt(a) (*(a)) + #define vec_store_psqt(a, b) *(a) = (b) + #define vec_add_psqt_32(a, b) _mm_add_pi32(a, b) + #define vec_sub_psqt_32(a, b) _mm_sub_pi32(a, b) + #define vec_zero_psqt() _mm_setzero_si64() + #define vec_cleanup() _mm_empty() + #define NumRegistersSIMD 8 + #define MaxChunkSize 8 + +#elif USE_NEON + using vec_t = int16x8_t; + using psqt_vec_t = int32x4_t; + #define vec_load(a) (*(a)) + #define vec_store(a, b) *(a) = (b) + #define vec_add_16(a, b) vaddq_s16(a, b) + #define vec_sub_16(a, b) vsubq_s16(a, b) + #define vec_mul_16(a, b) vmulq_s16(a, b) + #define vec_zero() \ + vec_t { 0 } + #define vec_set_16(a) vdupq_n_s16(a) + #define vec_max_16(a, b) vmaxq_s16(a, b) + #define vec_min_16(a, b) vminq_s16(a, b) + inline vec_t vec_msb_pack_16(vec_t a, vec_t b) { + const int8x8_t shifta = vshrn_n_s16(a, 7); + const int8x8_t shiftb = vshrn_n_s16(b, 7); + const int8x16_t compacted = vcombine_s8(shifta, shiftb); + return *reinterpret_cast(&compacted); + } + #define vec_load_psqt(a) (*(a)) + #define vec_store_psqt(a, b) *(a) = (b) + #define vec_add_psqt_32(a, b) vaddq_s32(a, b) + #define vec_sub_psqt_32(a, b) vsubq_s32(a, b) + #define vec_zero_psqt() \ + psqt_vec_t { 0 } + #define NumRegistersSIMD 16 + #define MaxChunkSize 16 + +#else + #undef VECTOR + +#endif + + +#ifdef VECTOR + + // Compute optimal SIMD register count for feature transformer accumulation. + + // We use __m* types as template arguments, which causes GCC to emit warnings + // about losing some attribute information. This is irrelevant to us as we + // only take their size, so the following pragma are harmless. + #if defined(__GNUC__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wignored-attributes" #endif - public: - // Output type - using OutputType = TransformedFeatureType; + template + static constexpr int BestRegisterCount() { + #define RegisterSize sizeof(SIMDRegisterType) + #define LaneSize sizeof(LaneType) - // Number of input/output dimensions - static constexpr IndexType InputDimensions = FeatureSet::Dimensions; - static constexpr IndexType OutputDimensions = HalfDimensions; + static_assert(RegisterSize >= LaneSize); + static_assert(MaxRegisters <= NumRegistersSIMD); + static_assert(MaxRegisters > 0); + static_assert(NumRegistersSIMD > 0); + static_assert(RegisterSize % LaneSize == 0); + static_assert((NumLanes * LaneSize) % RegisterSize == 0); - // Size of forward propagation buffer - static constexpr std::size_t BufferSize = - OutputDimensions * sizeof(OutputType); + const int ideal = (NumLanes * LaneSize) / RegisterSize; + if (ideal <= MaxRegisters) return ideal; - // Hash value embedded in the evaluation file - static constexpr std::uint32_t get_hash_value() { - return FeatureSet::HashValue ^ (OutputDimensions * 2); + // Look for the largest divisor of the ideal register count that is smaller than MaxRegisters + for (int divisor = MaxRegisters; divisor > 1; --divisor) + if (ideal % divisor == 0) return divisor; + + return 1; } - // Read network parameters - bool read_parameters(std::istream& stream) { + static constexpr int NumRegs = + BestRegisterCount(); + static constexpr int NumPsqtRegs = + BestRegisterCount(); + #if defined(__GNUC__) + #pragma GCC diagnostic pop + #endif +#endif - read_leb_128(stream, biases , HalfDimensions ); - read_leb_128(stream, weights , HalfDimensions * InputDimensions); - read_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); - return !stream.fail(); - } + // Input feature converter + class FeatureTransformer { - // Write network parameters - bool write_parameters(std::ostream& stream) const { + private: + // Number of output dimensions for one side + static constexpr IndexType HalfDimensions = TransformedFeatureDimensions; - write_leb_128(stream, biases , HalfDimensions ); - write_leb_128(stream, weights , HalfDimensions * InputDimensions); - write_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); +#ifdef VECTOR + static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2; + static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4; + static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions"); + static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets"); +#endif - return !stream.fail(); - } + public: + // Output type + using OutputType = TransformedFeatureType; - // Convert input features - std::int32_t transform(const Position& pos, OutputType* output, int bucket) const { - update_accumulator(pos); - update_accumulator(pos); + // Number of input/output dimensions + static constexpr IndexType InputDimensions = FeatureSet::Dimensions; + static constexpr IndexType OutputDimensions = HalfDimensions; - const Color perspectives[2] = {pos.side_to_move(), ~pos.side_to_move()}; - const auto& accumulation = pos.state()->accumulator.accumulation; - const auto& psqtAccumulation = pos.state()->accumulator.psqtAccumulation; + // Size of forward propagation buffer + static constexpr std::size_t BufferSize = OutputDimensions * sizeof(OutputType); - const auto psqt = ( - psqtAccumulation[perspectives[0]][bucket] - - psqtAccumulation[perspectives[1]][bucket] - ) / 2; + // Hash value embedded in the evaluation file + static constexpr std::uint32_t get_hash_value() { + return FeatureSet::HashValue ^ (OutputDimensions * 2); + } + // Read network parameters + bool read_parameters(std::istream& stream) { - for (IndexType p = 0; p < 2; ++p) - { - const IndexType offset = (HalfDimensions / 2) * p; + read_leb_128(stream, biases, HalfDimensions); + read_leb_128(stream, weights, HalfDimensions * InputDimensions); + read_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); + + return !stream.fail(); + } + + // Write network parameters + bool write_parameters(std::ostream& stream) const { + + write_leb_128(stream, biases, HalfDimensions); + write_leb_128(stream, weights, HalfDimensions * InputDimensions); + write_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); + + return !stream.fail(); + } + + // Convert input features + std::int32_t transform(const Position& pos, OutputType* output, int bucket) const { + update_accumulator(pos); + update_accumulator(pos); + + const Color perspectives[2] = {pos.side_to_move(), ~pos.side_to_move()}; + const auto& accumulation = pos.state()->accumulator.accumulation; + const auto& psqtAccumulation = pos.state()->accumulator.psqtAccumulation; + + const auto psqt = (psqtAccumulation[perspectives[0]][bucket] - + psqtAccumulation[perspectives[1]][bucket]) / + 2; + + + for (IndexType p = 0; p < 2; ++p) { + const IndexType offset = (HalfDimensions / 2) * p; #if defined(VECTOR) - constexpr IndexType OutputChunkSize = MaxChunkSize; - static_assert((HalfDimensions / 2) % OutputChunkSize == 0); - constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize; + constexpr IndexType OutputChunkSize = MaxChunkSize; + static_assert((HalfDimensions / 2) % OutputChunkSize == 0); + constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize; - vec_t Zero = vec_zero(); - vec_t One = vec_set_16(127); + vec_t Zero = vec_zero(); + vec_t One = vec_set_16(127); - const vec_t* in0 = reinterpret_cast(&(accumulation[perspectives[p]][0])); - const vec_t* in1 = reinterpret_cast(&(accumulation[perspectives[p]][HalfDimensions / 2])); - vec_t* out = reinterpret_cast< vec_t*>(output + offset); + const vec_t* in0 = + reinterpret_cast(&(accumulation[perspectives[p]][0])); + const vec_t* in1 = reinterpret_cast( + &(accumulation[perspectives[p]][HalfDimensions / 2])); + vec_t* out = reinterpret_cast(output + offset); - for (IndexType j = 0; j < NumOutputChunks; j += 1) - { - const vec_t sum0a = vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero); - const vec_t sum0b = vec_max_16(vec_min_16(in0[j * 2 + 1], One), Zero); - const vec_t sum1a = vec_max_16(vec_min_16(in1[j * 2 + 0], One), Zero); - const vec_t sum1b = vec_max_16(vec_min_16(in1[j * 2 + 1], One), Zero); + for (IndexType j = 0; j < NumOutputChunks; j += 1) { + const vec_t sum0a = vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero); + const vec_t sum0b = vec_max_16(vec_min_16(in0[j * 2 + 1], One), Zero); + const vec_t sum1a = vec_max_16(vec_min_16(in1[j * 2 + 0], One), Zero); + const vec_t sum1b = vec_max_16(vec_min_16(in1[j * 2 + 1], One), Zero); - const vec_t pa = vec_mul_16(sum0a, sum1a); - const vec_t pb = vec_mul_16(sum0b, sum1b); + const vec_t pa = vec_mul_16(sum0a, sum1a); + const vec_t pb = vec_mul_16(sum0b, sum1b); - out[j] = vec_msb_pack_16(pa, pb); - } + out[j] = vec_msb_pack_16(pa, pb); + } #else - for (IndexType j = 0; j < HalfDimensions / 2; ++j) { - BiasType sum0 = accumulation[static_cast(perspectives[p])][j + 0]; - BiasType sum1 = accumulation[static_cast(perspectives[p])][j + HalfDimensions / 2]; - sum0 = std::max(0, std::min(127, sum0)); - sum1 = std::max(0, std::min(127, sum1)); - output[offset + j] = static_cast(sum0 * sum1 / 128); - } + for (IndexType j = 0; j < HalfDimensions / 2; ++j) { + BiasType sum0 = accumulation[static_cast(perspectives[p])][j + 0]; + BiasType sum1 = + accumulation[static_cast(perspectives[p])][j + HalfDimensions / 2]; + sum0 = std::max(0, std::min(127, sum0)); + sum1 = std::max(0, std::min(127, sum1)); + output[offset + j] = static_cast(sum0 * sum1 / 128); + } #endif - } + } #if defined(vec_cleanup) - vec_cleanup(); + vec_cleanup(); #endif - return psqt; - } // end of function transform() + return psqt; + } // end of function transform() - void hint_common_access(const Position& pos) const { - hint_common_access_for_perspective(pos); - hint_common_access_for_perspective(pos); - } + void hint_common_access(const Position& pos) const { + hint_common_access_for_perspective(pos); + hint_common_access_for_perspective(pos); + } - private: - template - [[nodiscard]] std::pair try_find_computed_accumulator(const Position& pos) const { - // Look for a usable accumulator of an earlier position. We keep track - // of the estimated gain in terms of features to be added/subtracted. - StateInfo *st = pos.state(), *next = nullptr; - int gain = FeatureSet::refresh_cost(pos); - while (st->previous && !st->accumulator.computed[Perspective]) - { - // This governs when a full feature refresh is needed and how many - // updates are better than just one full refresh. - if ( FeatureSet::requires_refresh(st, Perspective) - || (gain -= FeatureSet::update_cost(st) + 1) < 0) - break; - next = st; - st = st->previous; - } - return { st, next }; - } + private: + template [[nodiscard]] std::pair + try_find_computed_accumulator(const Position& pos) const { + // Look for a usable accumulator of an earlier position. We keep track + // of the estimated gain in terms of features to be added/subtracted. + StateInfo *st = pos.state(), *next = nullptr; + int gain = FeatureSet::refresh_cost(pos); + while (st->previous && !st->accumulator.computed[Perspective]) { + // This governs when a full feature refresh is needed and how many + // updates are better than just one full refresh. + if (FeatureSet::requires_refresh(st, Perspective) || + (gain -= FeatureSet::update_cost(st) + 1) < 0) + break; + next = st; + st = st->previous; + } + return {st, next}; + } - // NOTE: The parameter states_to_update is an array of position states, ending with nullptr. - // All states must be sequential, that is states_to_update[i] must either be reachable - // by repeatedly applying ->previous from states_to_update[i+1] or states_to_update[i] == nullptr. - // computed_st must be reachable by repeatedly applying ->previous on states_to_update[0], if not nullptr. - template - void update_accumulator_incremental(const Position& pos, StateInfo* computed_st, StateInfo* states_to_update[N]) const { - static_assert(N > 0); - assert(states_to_update[N-1] == nullptr); + // NOTE: The parameter states_to_update is an array of position states, ending with nullptr. + // All states must be sequential, that is states_to_update[i] must either be reachable + // by repeatedly applying ->previous from states_to_update[i+1] or states_to_update[i] == nullptr. + // computed_st must be reachable by repeatedly applying ->previous on states_to_update[0], if not nullptr. + template + void update_accumulator_incremental(const Position& pos, StateInfo* computed_st, + StateInfo* states_to_update[N]) const { + static_assert(N > 0); + assert(states_to_update[N - 1] == nullptr); - #ifdef VECTOR - // Gcc-10.2 unnecessarily spills AVX2 registers if this array - // is defined in the VECTOR code below, once in each branch - vec_t acc[NumRegs]; - psqt_vec_t psqt[NumPsqtRegs]; - #endif +#ifdef VECTOR + // Gcc-10.2 unnecessarily spills AVX2 registers if this array + // is defined in the VECTOR code below, once in each branch + vec_t acc[NumRegs]; + psqt_vec_t psqt[NumPsqtRegs]; +#endif - if (states_to_update[0] == nullptr) - return; + if (states_to_update[0] == nullptr) return; - // Update incrementally going back through states_to_update. + // Update incrementally going back through states_to_update. - // Gather all features to be updated. - const Square ksq = pos.square(Perspective); + // Gather all features to be updated. + const Square ksq = pos.square(Perspective); - // The size must be enough to contain the largest possible update. - // That might depend on the feature set and generally relies on the - // feature set's update cost calculation to be correct and never - // allow updates with more added/removed features than MaxActiveDimensions. - FeatureSet::IndexList removed[N-1], added[N-1]; + // The size must be enough to contain the largest possible update. + // That might depend on the feature set and generally relies on the + // feature set's update cost calculation to be correct and never + // allow updates with more added/removed features than MaxActiveDimensions. + FeatureSet::IndexList removed[N - 1], added[N - 1]; - { - int i = N-2; // last potential state to update. Skip last element because it must be nullptr. - while (states_to_update[i] == nullptr) - --i; + { + int i = + N - + 2; // last potential state to update. Skip last element because it must be nullptr. + while (states_to_update[i] == nullptr) --i; - StateInfo *st2 = states_to_update[i]; + StateInfo* st2 = states_to_update[i]; - for (; i >= 0; --i) - { - states_to_update[i]->accumulator.computed[Perspective] = true; + for (; i >= 0; --i) { + states_to_update[i]->accumulator.computed[Perspective] = true; - StateInfo* end_state = i == 0 ? computed_st : states_to_update[i - 1]; + StateInfo* end_state = i == 0 ? computed_st : states_to_update[i - 1]; - for (; st2 != end_state; st2 = st2->previous) - FeatureSet::append_changed_indices( - ksq, st2->dirtyPiece, removed[i], added[i]); - } - } + for (; st2 != end_state; st2 = st2->previous) + FeatureSet::append_changed_indices(ksq, st2->dirtyPiece, + removed[i], added[i]); + } + } - StateInfo* st = computed_st; + StateInfo* st = computed_st; - // Now update the accumulators listed in states_to_update[], where the last element is a sentinel. + // Now update the accumulators listed in states_to_update[], where the last element is a sentinel. #ifdef VECTOR - for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j) - { - // Load accumulator - auto accTile = reinterpret_cast( - &st->accumulator.accumulation[Perspective][j * TileHeight]); - for (IndexType k = 0; k < NumRegs; ++k) - acc[k] = vec_load(&accTile[k]); - - for (IndexType i = 0; states_to_update[i]; ++i) - { - // Difference calculation for the deactivated features - for (const auto index : removed[i]) - { - const IndexType offset = HalfDimensions * index + j * TileHeight; - auto column = reinterpret_cast(&weights[offset]); - for (IndexType k = 0; k < NumRegs; ++k) - acc[k] = vec_sub_16(acc[k], column[k]); - } - - // Difference calculation for the activated features - for (const auto index : added[i]) - { - const IndexType offset = HalfDimensions * index + j * TileHeight; - auto column = reinterpret_cast(&weights[offset]); - for (IndexType k = 0; k < NumRegs; ++k) - acc[k] = vec_add_16(acc[k], column[k]); - } - - // Store accumulator - accTile = reinterpret_cast( - &states_to_update[i]->accumulator.accumulation[Perspective][j * TileHeight]); - for (IndexType k = 0; k < NumRegs; ++k) - vec_store(&accTile[k], acc[k]); - } - } - - for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j) - { - // Load accumulator - auto accTilePsqt = reinterpret_cast( - &st->accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]); - for (std::size_t k = 0; k < NumPsqtRegs; ++k) - psqt[k] = vec_load_psqt(&accTilePsqt[k]); - - for (IndexType i = 0; states_to_update[i]; ++i) - { - // Difference calculation for the deactivated features - for (const auto index : removed[i]) - { - const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight; - auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); - for (std::size_t k = 0; k < NumPsqtRegs; ++k) - psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]); - } - - // Difference calculation for the activated features - for (const auto index : added[i]) - { - const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight; - auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); - for (std::size_t k = 0; k < NumPsqtRegs; ++k) - psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]); - } - - // Store accumulator - accTilePsqt = reinterpret_cast( - &states_to_update[i]->accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]); - for (std::size_t k = 0; k < NumPsqtRegs; ++k) - vec_store_psqt(&accTilePsqt[k], psqt[k]); - } - } + for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j) { + // Load accumulator + auto accTile = reinterpret_cast( + &st->accumulator.accumulation[Perspective][j * TileHeight]); + for (IndexType k = 0; k < NumRegs; ++k) acc[k] = vec_load(&accTile[k]); + + for (IndexType i = 0; states_to_update[i]; ++i) { + // Difference calculation for the deactivated features + for (const auto index : removed[i]) { + const IndexType offset = HalfDimensions * index + j * TileHeight; + auto column = reinterpret_cast(&weights[offset]); + for (IndexType k = 0; k < NumRegs; ++k) + acc[k] = vec_sub_16(acc[k], column[k]); + } + + // Difference calculation for the activated features + for (const auto index : added[i]) { + const IndexType offset = HalfDimensions * index + j * TileHeight; + auto column = reinterpret_cast(&weights[offset]); + for (IndexType k = 0; k < NumRegs; ++k) + acc[k] = vec_add_16(acc[k], column[k]); + } + + // Store accumulator + accTile = reinterpret_cast( + &states_to_update[i]->accumulator.accumulation[Perspective][j * TileHeight]); + for (IndexType k = 0; k < NumRegs; ++k) vec_store(&accTile[k], acc[k]); + } + } + + for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j) { + // Load accumulator + auto accTilePsqt = reinterpret_cast( + &st->accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]); + for (std::size_t k = 0; k < NumPsqtRegs; ++k) + psqt[k] = vec_load_psqt(&accTilePsqt[k]); + + for (IndexType i = 0; states_to_update[i]; ++i) { + // Difference calculation for the deactivated features + for (const auto index : removed[i]) { + const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight; + auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); + for (std::size_t k = 0; k < NumPsqtRegs; ++k) + psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]); + } + + // Difference calculation for the activated features + for (const auto index : added[i]) { + const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight; + auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); + for (std::size_t k = 0; k < NumPsqtRegs; ++k) + psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]); + } + + // Store accumulator + accTilePsqt = reinterpret_cast( + &states_to_update[i] + ->accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]); + for (std::size_t k = 0; k < NumPsqtRegs; ++k) + vec_store_psqt(&accTilePsqt[k], psqt[k]); + } + } #else - for (IndexType i = 0; states_to_update[i]; ++i) - { - std::memcpy(states_to_update[i]->accumulator.accumulation[Perspective], - st->accumulator.accumulation[Perspective], - HalfDimensions * sizeof(BiasType)); + for (IndexType i = 0; states_to_update[i]; ++i) { + std::memcpy(states_to_update[i]->accumulator.accumulation[Perspective], + st->accumulator.accumulation[Perspective], + HalfDimensions * sizeof(BiasType)); - for (std::size_t k = 0; k < PSQTBuckets; ++k) - states_to_update[i]->accumulator.psqtAccumulation[Perspective][k] = st->accumulator.psqtAccumulation[Perspective][k]; + for (std::size_t k = 0; k < PSQTBuckets; ++k) + states_to_update[i]->accumulator.psqtAccumulation[Perspective][k] = + st->accumulator.psqtAccumulation[Perspective][k]; - st = states_to_update[i]; + st = states_to_update[i]; - // Difference calculation for the deactivated features - for (const auto index : removed[i]) - { - const IndexType offset = HalfDimensions * index; + // Difference calculation for the deactivated features + for (const auto index : removed[i]) { + const IndexType offset = HalfDimensions * index; - for (IndexType j = 0; j < HalfDimensions; ++j) - st->accumulator.accumulation[Perspective][j] -= weights[offset + j]; + for (IndexType j = 0; j < HalfDimensions; ++j) + st->accumulator.accumulation[Perspective][j] -= weights[offset + j]; - for (std::size_t k = 0; k < PSQTBuckets; ++k) - st->accumulator.psqtAccumulation[Perspective][k] -= psqtWeights[index * PSQTBuckets + k]; - } + for (std::size_t k = 0; k < PSQTBuckets; ++k) + st->accumulator.psqtAccumulation[Perspective][k] -= + psqtWeights[index * PSQTBuckets + k]; + } - // Difference calculation for the activated features - for (const auto index : added[i]) - { - const IndexType offset = HalfDimensions * index; + // Difference calculation for the activated features + for (const auto index : added[i]) { + const IndexType offset = HalfDimensions * index; - for (IndexType j = 0; j < HalfDimensions; ++j) - st->accumulator.accumulation[Perspective][j] += weights[offset + j]; + for (IndexType j = 0; j < HalfDimensions; ++j) + st->accumulator.accumulation[Perspective][j] += weights[offset + j]; - for (std::size_t k = 0; k < PSQTBuckets; ++k) - st->accumulator.psqtAccumulation[Perspective][k] += psqtWeights[index * PSQTBuckets + k]; - } - } + for (std::size_t k = 0; k < PSQTBuckets; ++k) + st->accumulator.psqtAccumulation[Perspective][k] += + psqtWeights[index * PSQTBuckets + k]; + } + } #endif - #if defined(USE_MMX) - _mm_empty(); - #endif - } +#if defined(USE_MMX) + _mm_empty(); +#endif + } + + template void update_accumulator_refresh(const Position& pos) const { +#ifdef VECTOR + // Gcc-10.2 unnecessarily spills AVX2 registers if this array + // is defined in the VECTOR code below, once in each branch + vec_t acc[NumRegs]; + psqt_vec_t psqt[NumPsqtRegs]; +#endif - template - void update_accumulator_refresh(const Position& pos) const { - #ifdef VECTOR - // Gcc-10.2 unnecessarily spills AVX2 registers if this array - // is defined in the VECTOR code below, once in each branch - vec_t acc[NumRegs]; - psqt_vec_t psqt[NumPsqtRegs]; - #endif - - // Refresh the accumulator - // Could be extracted to a separate function because it's done in 2 places, - // but it's unclear if compilers would correctly handle register allocation. - auto& accumulator = pos.state()->accumulator; - accumulator.computed[Perspective] = true; - FeatureSet::IndexList active; - FeatureSet::append_active_indices(pos, active); + // Refresh the accumulator + // Could be extracted to a separate function because it's done in 2 places, + // but it's unclear if compilers would correctly handle register allocation. + auto& accumulator = pos.state()->accumulator; + accumulator.computed[Perspective] = true; + FeatureSet::IndexList active; + FeatureSet::append_active_indices(pos, active); #ifdef VECTOR - for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j) - { - auto biasesTile = reinterpret_cast( - &biases[j * TileHeight]); - for (IndexType k = 0; k < NumRegs; ++k) - acc[k] = biasesTile[k]; - - for (const auto index : active) - { - const IndexType offset = HalfDimensions * index + j * TileHeight; - auto column = reinterpret_cast(&weights[offset]); - - for (unsigned k = 0; k < NumRegs; ++k) - acc[k] = vec_add_16(acc[k], column[k]); - } + for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j) { + auto biasesTile = reinterpret_cast(&biases[j * TileHeight]); + for (IndexType k = 0; k < NumRegs; ++k) acc[k] = biasesTile[k]; - auto accTile = reinterpret_cast( - &accumulator.accumulation[Perspective][j * TileHeight]); - for (unsigned k = 0; k < NumRegs; k++) - vec_store(&accTile[k], acc[k]); - } + for (const auto index : active) { + const IndexType offset = HalfDimensions * index + j * TileHeight; + auto column = reinterpret_cast(&weights[offset]); - for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j) - { - for (std::size_t k = 0; k < NumPsqtRegs; ++k) - psqt[k] = vec_zero_psqt(); + for (unsigned k = 0; k < NumRegs; ++k) acc[k] = vec_add_16(acc[k], column[k]); + } - for (const auto index : active) - { - const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight; - auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); + auto accTile = + reinterpret_cast(&accumulator.accumulation[Perspective][j * TileHeight]); + for (unsigned k = 0; k < NumRegs; k++) vec_store(&accTile[k], acc[k]); + } - for (std::size_t k = 0; k < NumPsqtRegs; ++k) - psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]); - } + for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j) { + for (std::size_t k = 0; k < NumPsqtRegs; ++k) psqt[k] = vec_zero_psqt(); - auto accTilePsqt = reinterpret_cast( - &accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]); - for (std::size_t k = 0; k < NumPsqtRegs; ++k) - vec_store_psqt(&accTilePsqt[k], psqt[k]); - } + for (const auto index : active) { + const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight; + auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); + + for (std::size_t k = 0; k < NumPsqtRegs; ++k) + psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]); + } + + auto accTilePsqt = reinterpret_cast( + &accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]); + for (std::size_t k = 0; k < NumPsqtRegs; ++k) + vec_store_psqt(&accTilePsqt[k], psqt[k]); + } #else - std::memcpy(accumulator.accumulation[Perspective], biases, - HalfDimensions * sizeof(BiasType)); + std::memcpy(accumulator.accumulation[Perspective], biases, + HalfDimensions * sizeof(BiasType)); - for (std::size_t k = 0; k < PSQTBuckets; ++k) - accumulator.psqtAccumulation[Perspective][k] = 0; + for (std::size_t k = 0; k < PSQTBuckets; ++k) + accumulator.psqtAccumulation[Perspective][k] = 0; - for (const auto index : active) - { - const IndexType offset = HalfDimensions * index; + for (const auto index : active) { + const IndexType offset = HalfDimensions * index; - for (IndexType j = 0; j < HalfDimensions; ++j) - accumulator.accumulation[Perspective][j] += weights[offset + j]; + for (IndexType j = 0; j < HalfDimensions; ++j) + accumulator.accumulation[Perspective][j] += weights[offset + j]; - for (std::size_t k = 0; k < PSQTBuckets; ++k) - accumulator.psqtAccumulation[Perspective][k] += psqtWeights[index * PSQTBuckets + k]; - } + for (std::size_t k = 0; k < PSQTBuckets; ++k) + accumulator.psqtAccumulation[Perspective][k] += + psqtWeights[index * PSQTBuckets + k]; + } #endif - #if defined(USE_MMX) - _mm_empty(); - #endif - } +#if defined(USE_MMX) + _mm_empty(); +#endif + } - template - void hint_common_access_for_perspective(const Position& pos) const { - - // Works like update_accumulator, but performs less work. - // Updates ONLY the accumulator for pos. - - // Look for a usable accumulator of an earlier position. We keep track - // of the estimated gain in terms of features to be added/subtracted. - // Fast early exit. - if (pos.state()->accumulator.computed[Perspective]) - return; - - auto [oldest_st, _] = try_find_computed_accumulator(pos); - - if (oldest_st->accumulator.computed[Perspective]) - { - // Only update current position accumulator to minimize work. - StateInfo* states_to_update[2] = { pos.state(), nullptr }; - update_accumulator_incremental(pos, oldest_st, states_to_update); - } - else - { - update_accumulator_refresh(pos); - } - } + template + void hint_common_access_for_perspective(const Position& pos) const { - template - void update_accumulator(const Position& pos) const { - - auto [oldest_st, next] = try_find_computed_accumulator(pos); - - if (oldest_st->accumulator.computed[Perspective]) - { - if (next == nullptr) - return; - - // Now update the accumulators listed in states_to_update[], where the last element is a sentinel. - // Currently we update 2 accumulators. - // 1. for the current position - // 2. the next accumulator after the computed one - // The heuristic may change in the future. - StateInfo *states_to_update[3] = - { next, next == pos.state() ? nullptr : pos.state(), nullptr }; - - update_accumulator_incremental(pos, oldest_st, states_to_update); - } - else - { - update_accumulator_refresh(pos); - } - } + // Works like update_accumulator, but performs less work. + // Updates ONLY the accumulator for pos. + + // Look for a usable accumulator of an earlier position. We keep track + // of the estimated gain in terms of features to be added/subtracted. + // Fast early exit. + if (pos.state()->accumulator.computed[Perspective]) return; + + auto [oldest_st, _] = try_find_computed_accumulator(pos); + + if (oldest_st->accumulator.computed[Perspective]) { + // Only update current position accumulator to minimize work. + StateInfo* states_to_update[2] = {pos.state(), nullptr}; + update_accumulator_incremental(pos, oldest_st, states_to_update); + } else { + update_accumulator_refresh(pos); + } + } + + template void update_accumulator(const Position& pos) const { + + auto [oldest_st, next] = try_find_computed_accumulator(pos); + + if (oldest_st->accumulator.computed[Perspective]) { + if (next == nullptr) return; + + // Now update the accumulators listed in states_to_update[], where the last element is a sentinel. + // Currently we update 2 accumulators. + // 1. for the current position + // 2. the next accumulator after the computed one + // The heuristic may change in the future. + StateInfo* states_to_update[3] = {next, next == pos.state() ? nullptr : pos.state(), + nullptr}; + + update_accumulator_incremental(pos, oldest_st, states_to_update); + } else { + update_accumulator_refresh(pos); + } + } - alignas(CacheLineSize) BiasType biases[HalfDimensions]; - alignas(CacheLineSize) WeightType weights[HalfDimensions * InputDimensions]; - alignas(CacheLineSize) PSQTWeightType psqtWeights[InputDimensions * PSQTBuckets]; - }; + alignas(CacheLineSize) BiasType biases[HalfDimensions]; + alignas(CacheLineSize) WeightType weights[HalfDimensions * InputDimensions]; + alignas(CacheLineSize) PSQTWeightType psqtWeights[InputDimensions * PSQTBuckets]; + }; } // namespace Stockfish::Eval::NNUE -#endif // #ifndef NNUE_FEATURE_TRANSFORMER_H_INCLUDED +#endif // #ifndef NNUE_FEATURE_TRANSFORMER_H_INCLUDED diff --git a/src/position.cpp b/src/position.cpp index 120677432b6..01dddad20e1 100644 --- a/src/position.cpp +++ b/src/position.cpp @@ -44,127 +44,120 @@ using std::string; namespace Stockfish { -namespace Zobrist { + namespace Zobrist { - Key psq[PIECE_NB][SQUARE_NB]; - Key enpassant[FILE_NB]; - Key castling[CASTLING_RIGHT_NB]; - Key side; -} + Key psq[PIECE_NB][SQUARE_NB]; + Key enpassant[FILE_NB]; + Key castling[CASTLING_RIGHT_NB]; + Key side; + } -namespace { + namespace { -constexpr std::string_view PieceToChar(" PNBRQK pnbrqk"); + constexpr std::string_view PieceToChar(" PNBRQK pnbrqk"); -constexpr Piece Pieces[] = { W_PAWN, W_KNIGHT, W_BISHOP, W_ROOK, W_QUEEN, W_KING, - B_PAWN, B_KNIGHT, B_BISHOP, B_ROOK, B_QUEEN, B_KING }; -} // namespace + constexpr Piece Pieces[] = {W_PAWN, W_KNIGHT, W_BISHOP, W_ROOK, W_QUEEN, W_KING, + B_PAWN, B_KNIGHT, B_BISHOP, B_ROOK, B_QUEEN, B_KING}; + } // namespace -/// operator<<(Position) returns an ASCII representation of the position + /// operator<<(Position) returns an ASCII representation of the position -std::ostream& operator<<(std::ostream& os, const Position& pos) { + std::ostream& operator<<(std::ostream& os, const Position& pos) { - os << "\n +---+---+---+---+---+---+---+---+\n"; + os << "\n +---+---+---+---+---+---+---+---+\n"; - for (Rank r = RANK_8; r >= RANK_1; --r) - { - for (File f = FILE_A; f <= FILE_H; ++f) - os << " | " << PieceToChar[pos.piece_on(make_square(f, r))]; + for (Rank r = RANK_8; r >= RANK_1; --r) { + for (File f = FILE_A; f <= FILE_H; ++f) + os << " | " << PieceToChar[pos.piece_on(make_square(f, r))]; - os << " | " << (1 + r) << "\n +---+---+---+---+---+---+---+---+\n"; - } + os << " | " << (1 + r) << "\n +---+---+---+---+---+---+---+---+\n"; + } - os << " a b c d e f g h\n" - << "\nFen: " << pos.fen() << "\nKey: " << std::hex << std::uppercase - << std::setfill('0') << std::setw(16) << pos.key() - << std::setfill(' ') << std::dec << "\nCheckers: "; + os << " a b c d e f g h\n" + << "\nFen: " << pos.fen() << "\nKey: " << std::hex << std::uppercase << std::setfill('0') + << std::setw(16) << pos.key() << std::setfill(' ') << std::dec << "\nCheckers: "; - for (Bitboard b = pos.checkers(); b; ) - os << UCI::square(pop_lsb(b)) << " "; + for (Bitboard b = pos.checkers(); b;) os << UCI::square(pop_lsb(b)) << " "; - if ( int(Tablebases::MaxCardinality) >= popcount(pos.pieces()) - && !pos.can_castle(ANY_CASTLING)) - { - StateInfo st; - ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize); + if (int(Tablebases::MaxCardinality) >= popcount(pos.pieces()) && + !pos.can_castle(ANY_CASTLING)) { + StateInfo st; + ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize); - Position p; - p.set(pos.fen(), pos.is_chess960(), &st, pos.this_thread()); - Tablebases::ProbeState s1, s2; - Tablebases::WDLScore wdl = Tablebases::probe_wdl(p, &s1); - int dtz = Tablebases::probe_dtz(p, &s2); - os << "\nTablebases WDL: " << std::setw(4) << wdl << " (" << s1 << ")" - << "\nTablebases DTZ: " << std::setw(4) << dtz << " (" << s2 << ")"; - } + Position p; + p.set(pos.fen(), pos.is_chess960(), &st, pos.this_thread()); + Tablebases::ProbeState s1, s2; + Tablebases::WDLScore wdl = Tablebases::probe_wdl(p, &s1); + int dtz = Tablebases::probe_dtz(p, &s2); + os << "\nTablebases WDL: " << std::setw(4) << wdl << " (" << s1 << ")" + << "\nTablebases DTZ: " << std::setw(4) << dtz << " (" << s2 << ")"; + } - return os; -} + return os; + } -// Marcel van Kervinck's cuckoo algorithm for fast detection of "upcoming repetition" -// situations. Description of the algorithm in the following paper: -// http://web.archive.org/web/20201107002606/https://marcelk.net/2013-04-06/paper/upcoming-rep-v2.pdf + // Marcel van Kervinck's cuckoo algorithm for fast detection of "upcoming repetition" + // situations. Description of the algorithm in the following paper: + // http://web.archive.org/web/20201107002606/https://marcelk.net/2013-04-06/paper/upcoming-rep-v2.pdf -// First and second hash functions for indexing the cuckoo tables -inline int H1(Key h) { return h & 0x1fff; } -inline int H2(Key h) { return (h >> 16) & 0x1fff; } + // First and second hash functions for indexing the cuckoo tables + inline int H1(Key h) { return h & 0x1fff; } + inline int H2(Key h) { return (h >> 16) & 0x1fff; } -// Cuckoo tables with Zobrist hashes of valid reversible moves, and the moves themselves -Key cuckoo[8192]; -Move cuckooMove[8192]; + // Cuckoo tables with Zobrist hashes of valid reversible moves, and the moves themselves + Key cuckoo[8192]; + Move cuckooMove[8192]; -/// Position::init() initializes at startup the various arrays used to compute hash keys + /// Position::init() initializes at startup the various arrays used to compute hash keys -void Position::init() { + void Position::init() { - PRNG rng(1070372); + PRNG rng(1070372); - for (Piece pc : Pieces) - for (Square s = SQ_A1; s <= SQ_H8; ++s) - Zobrist::psq[pc][s] = rng.rand(); + for (Piece pc : Pieces) + for (Square s = SQ_A1; s <= SQ_H8; ++s) Zobrist::psq[pc][s] = rng.rand(); - for (File f = FILE_A; f <= FILE_H; ++f) - Zobrist::enpassant[f] = rng.rand(); + for (File f = FILE_A; f <= FILE_H; ++f) Zobrist::enpassant[f] = rng.rand(); - for (int cr = NO_CASTLING; cr <= ANY_CASTLING; ++cr) - Zobrist::castling[cr] = rng.rand(); + for (int cr = NO_CASTLING; cr <= ANY_CASTLING; ++cr) + Zobrist::castling[cr] = rng.rand(); - Zobrist::side = rng.rand(); + Zobrist::side = rng.rand(); - // Prepare the cuckoo tables - std::memset(cuckoo, 0, sizeof(cuckoo)); - std::memset(cuckooMove, 0, sizeof(cuckooMove)); - [[maybe_unused]] int count = 0; - for (Piece pc : Pieces) - for (Square s1 = SQ_A1; s1 <= SQ_H8; ++s1) - for (Square s2 = Square(s1 + 1); s2 <= SQ_H8; ++s2) - if ((type_of(pc) != PAWN) && (attacks_bb(type_of(pc), s1, 0) & s2)) - { - Move move = make_move(s1, s2); - Key key = Zobrist::psq[pc][s1] ^ Zobrist::psq[pc][s2] ^ Zobrist::side; - int i = H1(key); - while (true) - { - std::swap(cuckoo[i], key); - std::swap(cuckooMove[i], move); - if (move == MOVE_NONE) // Arrived at empty slot? - break; - i = (i == H1(key)) ? H2(key) : H1(key); // Push victim to alternative slot - } - count++; - } - assert(count == 3668); -} + // Prepare the cuckoo tables + std::memset(cuckoo, 0, sizeof(cuckoo)); + std::memset(cuckooMove, 0, sizeof(cuckooMove)); + [[maybe_unused]] int count = 0; + for (Piece pc : Pieces) + for (Square s1 = SQ_A1; s1 <= SQ_H8; ++s1) + for (Square s2 = Square(s1 + 1); s2 <= SQ_H8; ++s2) + if ((type_of(pc) != PAWN) && (attacks_bb(type_of(pc), s1, 0) & s2)) { + Move move = make_move(s1, s2); + Key key = Zobrist::psq[pc][s1] ^ Zobrist::psq[pc][s2] ^ Zobrist::side; + int i = H1(key); + while (true) { + std::swap(cuckoo[i], key); + std::swap(cuckooMove[i], move); + if (move == MOVE_NONE) // Arrived at empty slot? + break; + i = (i == H1(key)) ? H2(key) : + H1(key); // Push victim to alternative slot + } + count++; + } + assert(count == 3668); + } -/// Position::set() initializes the position object with the given FEN string. -/// This function is not very robust - make sure that input FENs are correct, -/// this is assumed to be the responsibility of the GUI. + /// Position::set() initializes the position object with the given FEN string. + /// This function is not very robust - make sure that input FENs are correct, + /// this is assumed to be the responsibility of the GUI. -Position& Position::set(const string& fenStr, bool isChess960, StateInfo* si, Thread* th) { -/* + Position& Position::set(const string& fenStr, bool isChess960, StateInfo* si, Thread* th) { + /* A FEN string defines a particular position using only the ASCII character set. A FEN string contains six fields separated by a space. The fields are: @@ -199,1129 +192,1034 @@ Position& Position::set(const string& fenStr, bool isChess960, StateInfo* si, Th incremented after Black's move. */ - unsigned char col, row, token; - size_t idx; - Square sq = SQ_A8; - std::istringstream ss(fenStr); + unsigned char col, row, token; + size_t idx; + Square sq = SQ_A8; + std::istringstream ss(fenStr); + + std::memset(this, 0, sizeof(Position)); + std::memset(si, 0, sizeof(StateInfo)); + st = si; - std::memset(this, 0, sizeof(Position)); - std::memset(si, 0, sizeof(StateInfo)); - st = si; + ss >> std::noskipws; - ss >> std::noskipws; + // 1. Piece placement + while ((ss >> token) && !isspace(token)) { + if (isdigit(token)) + sq += (token - '0') * EAST; // Advance the given number of files - // 1. Piece placement - while ((ss >> token) && !isspace(token)) - { - if (isdigit(token)) - sq += (token - '0') * EAST; // Advance the given number of files + else if (token == '/') + sq += 2 * SOUTH; - else if (token == '/') - sq += 2 * SOUTH; + else if ((idx = PieceToChar.find(token)) != string::npos) { + put_piece(Piece(idx), sq); + ++sq; + } + } - else if ((idx = PieceToChar.find(token)) != string::npos) { - put_piece(Piece(idx), sq); - ++sq; - } - } + // 2. Active color + ss >> token; + sideToMove = (token == 'w' ? WHITE : BLACK); + ss >> token; - // 2. Active color - ss >> token; - sideToMove = (token == 'w' ? WHITE : BLACK); - ss >> token; + // 3. Castling availability. Compatible with 3 standards: Normal FEN standard, + // Shredder-FEN that uses the letters of the columns on which the rooks began + // the game instead of KQkq and also X-FEN standard that, in case of Chess960, + // if an inner rook is associated with the castling right, the castling tag is + // replaced by the file letter of the involved rook, as for the Shredder-FEN. + while ((ss >> token) && !isspace(token)) { + Square rsq; + Color c = islower(token) ? BLACK : WHITE; + Piece rook = make_piece(c, ROOK); - // 3. Castling availability. Compatible with 3 standards: Normal FEN standard, - // Shredder-FEN that uses the letters of the columns on which the rooks began - // the game instead of KQkq and also X-FEN standard that, in case of Chess960, - // if an inner rook is associated with the castling right, the castling tag is - // replaced by the file letter of the involved rook, as for the Shredder-FEN. - while ((ss >> token) && !isspace(token)) - { - Square rsq; - Color c = islower(token) ? BLACK : WHITE; - Piece rook = make_piece(c, ROOK); + token = char(toupper(token)); - token = char(toupper(token)); + if (token == 'K') + for (rsq = relative_square(c, SQ_H1); piece_on(rsq) != rook; --rsq) {} - if (token == 'K') - for (rsq = relative_square(c, SQ_H1); piece_on(rsq) != rook; --rsq) {} + else if (token == 'Q') + for (rsq = relative_square(c, SQ_A1); piece_on(rsq) != rook; ++rsq) {} - else if (token == 'Q') - for (rsq = relative_square(c, SQ_A1); piece_on(rsq) != rook; ++rsq) {} + else if (token >= 'A' && token <= 'H') + rsq = make_square(File(token - 'A'), relative_rank(c, RANK_1)); - else if (token >= 'A' && token <= 'H') - rsq = make_square(File(token - 'A'), relative_rank(c, RANK_1)); + else + continue; - else - continue; + set_castling_right(c, rsq); + } - set_castling_right(c, rsq); - } + // 4. En passant square. + // Ignore if square is invalid or not on side to move relative rank 6. + bool enpassant = false; - // 4. En passant square. - // Ignore if square is invalid or not on side to move relative rank 6. - bool enpassant = false; + if (((ss >> col) && (col >= 'a' && col <= 'h')) && + ((ss >> row) && (row == (sideToMove == WHITE ? '6' : '3')))) { + st->epSquare = make_square(File(col - 'a'), Rank(row - '1')); - if ( ((ss >> col) && (col >= 'a' && col <= 'h')) - && ((ss >> row) && (row == (sideToMove == WHITE ? '6' : '3')))) - { - st->epSquare = make_square(File(col - 'a'), Rank(row - '1')); + // En passant square will be considered only if + // a) side to move have a pawn threatening epSquare + // b) there is an enemy pawn in front of epSquare + // c) there is no piece on epSquare or behind epSquare + enpassant = pawn_attacks_bb(~sideToMove, st->epSquare) & pieces(sideToMove, PAWN) && + (pieces(~sideToMove, PAWN) & (st->epSquare + pawn_push(~sideToMove))) && + !(pieces() & (st->epSquare | (st->epSquare + pawn_push(sideToMove)))); + } - // En passant square will be considered only if - // a) side to move have a pawn threatening epSquare - // b) there is an enemy pawn in front of epSquare - // c) there is no piece on epSquare or behind epSquare - enpassant = pawn_attacks_bb(~sideToMove, st->epSquare) & pieces(sideToMove, PAWN) - && (pieces(~sideToMove, PAWN) & (st->epSquare + pawn_push(~sideToMove))) - && !(pieces() & (st->epSquare | (st->epSquare + pawn_push(sideToMove)))); - } + if (!enpassant) st->epSquare = SQ_NONE; - if (!enpassant) - st->epSquare = SQ_NONE; + // 5-6. Halfmove clock and fullmove number + ss >> std::skipws >> st->rule50 >> gamePly; - // 5-6. Halfmove clock and fullmove number - ss >> std::skipws >> st->rule50 >> gamePly; + // Convert from fullmove starting from 1 to gamePly starting from 0, + // handle also common incorrect FEN with fullmove = 0. + gamePly = std::max(2 * (gamePly - 1), 0) + (sideToMove == BLACK); + + chess960 = isChess960; + thisThread = th; + set_state(); + + assert(pos_is_ok()); + + return *this; + } - // Convert from fullmove starting from 1 to gamePly starting from 0, - // handle also common incorrect FEN with fullmove = 0. - gamePly = std::max(2 * (gamePly - 1), 0) + (sideToMove == BLACK); - chess960 = isChess960; - thisThread = th; - set_state(); + /// Position::set_castling_right() is a helper function used to set castling + /// rights given the corresponding color and the rook starting square. - assert(pos_is_ok()); + void Position::set_castling_right(Color c, Square rfrom) { - return *this; -} + Square kfrom = square(c); + CastlingRights cr = c & (kfrom < rfrom ? KING_SIDE : QUEEN_SIDE); + st->castlingRights |= cr; + castlingRightsMask[kfrom] |= cr; + castlingRightsMask[rfrom] |= cr; + castlingRookSquare[cr] = rfrom; -/// Position::set_castling_right() is a helper function used to set castling -/// rights given the corresponding color and the rook starting square. + Square kto = relative_square(c, cr & KING_SIDE ? SQ_G1 : SQ_C1); + Square rto = relative_square(c, cr & KING_SIDE ? SQ_F1 : SQ_D1); -void Position::set_castling_right(Color c, Square rfrom) { + castlingPath[cr] = (between_bb(rfrom, rto) | between_bb(kfrom, kto)) & ~(kfrom | rfrom); + } - Square kfrom = square(c); - CastlingRights cr = c & (kfrom < rfrom ? KING_SIDE: QUEEN_SIDE); - st->castlingRights |= cr; - castlingRightsMask[kfrom] |= cr; - castlingRightsMask[rfrom] |= cr; - castlingRookSquare[cr] = rfrom; + /// Position::set_check_info() sets king attacks to detect if a move gives check - Square kto = relative_square(c, cr & KING_SIDE ? SQ_G1 : SQ_C1); - Square rto = relative_square(c, cr & KING_SIDE ? SQ_F1 : SQ_D1); + void Position::set_check_info() const { - castlingPath[cr] = (between_bb(rfrom, rto) | between_bb(kfrom, kto)) - & ~(kfrom | rfrom); -} + update_slider_blockers(WHITE); + update_slider_blockers(BLACK); + Square ksq = square(~sideToMove); -/// Position::set_check_info() sets king attacks to detect if a move gives check + st->checkSquares[PAWN] = pawn_attacks_bb(~sideToMove, ksq); + st->checkSquares[KNIGHT] = attacks_bb(ksq); + st->checkSquares[BISHOP] = attacks_bb(ksq, pieces()); + st->checkSquares[ROOK] = attacks_bb(ksq, pieces()); + st->checkSquares[QUEEN] = st->checkSquares[BISHOP] | st->checkSquares[ROOK]; + st->checkSquares[KING] = 0; + } -void Position::set_check_info() const { - update_slider_blockers(WHITE); - update_slider_blockers(BLACK); + /// Position::set_state() computes the hash keys of the position, and other + /// data that once computed is updated incrementally as moves are made. + /// The function is only used when a new position is set up - Square ksq = square(~sideToMove); + void Position::set_state() const { - st->checkSquares[PAWN] = pawn_attacks_bb(~sideToMove, ksq); - st->checkSquares[KNIGHT] = attacks_bb(ksq); - st->checkSquares[BISHOP] = attacks_bb(ksq, pieces()); - st->checkSquares[ROOK] = attacks_bb(ksq, pieces()); - st->checkSquares[QUEEN] = st->checkSquares[BISHOP] | st->checkSquares[ROOK]; - st->checkSquares[KING] = 0; -} + st->key = st->materialKey = 0; + st->nonPawnMaterial[WHITE] = st->nonPawnMaterial[BLACK] = VALUE_ZERO; + st->checkersBB = attackers_to(square(sideToMove)) & pieces(~sideToMove); + set_check_info(); -/// Position::set_state() computes the hash keys of the position, and other -/// data that once computed is updated incrementally as moves are made. -/// The function is only used when a new position is set up + for (Bitboard b = pieces(); b;) { + Square s = pop_lsb(b); + Piece pc = piece_on(s); + st->key ^= Zobrist::psq[pc][s]; -void Position::set_state() const { + if (type_of(pc) != KING && type_of(pc) != PAWN) + st->nonPawnMaterial[color_of(pc)] += PieceValue[pc]; + } - st->key = st->materialKey = 0; - st->nonPawnMaterial[WHITE] = st->nonPawnMaterial[BLACK] = VALUE_ZERO; - st->checkersBB = attackers_to(square(sideToMove)) & pieces(~sideToMove); + if (st->epSquare != SQ_NONE) st->key ^= Zobrist::enpassant[file_of(st->epSquare)]; - set_check_info(); + if (sideToMove == BLACK) st->key ^= Zobrist::side; - for (Bitboard b = pieces(); b; ) - { - Square s = pop_lsb(b); - Piece pc = piece_on(s); - st->key ^= Zobrist::psq[pc][s]; + st->key ^= Zobrist::castling[st->castlingRights]; - if (type_of(pc) != KING && type_of(pc) != PAWN) - st->nonPawnMaterial[color_of(pc)] += PieceValue[pc]; - } + for (Piece pc : Pieces) + for (int cnt = 0; cnt < pieceCount[pc]; ++cnt) st->materialKey ^= Zobrist::psq[pc][cnt]; + } - if (st->epSquare != SQ_NONE) - st->key ^= Zobrist::enpassant[file_of(st->epSquare)]; - if (sideToMove == BLACK) - st->key ^= Zobrist::side; + /// Position::set() is an overload to initialize the position object with + /// the given endgame code string like "KBPKN". It is mainly a helper to + /// get the material key out of an endgame code. - st->key ^= Zobrist::castling[st->castlingRights]; + Position& Position::set(const string& code, Color c, StateInfo* si) { - for (Piece pc : Pieces) - for (int cnt = 0; cnt < pieceCount[pc]; ++cnt) - st->materialKey ^= Zobrist::psq[pc][cnt]; -} + assert(code[0] == 'K'); + string sides[] = {code.substr(code.find('K', 1)), // Weak + code.substr(0, std::min(code.find('v'), code.find('K', 1)))}; // Strong -/// Position::set() is an overload to initialize the position object with -/// the given endgame code string like "KBPKN". It is mainly a helper to -/// get the material key out of an endgame code. + assert(sides[0].length() > 0 && sides[0].length() < 8); + assert(sides[1].length() > 0 && sides[1].length() < 8); -Position& Position::set(const string& code, Color c, StateInfo* si) { + std::transform(sides[c].begin(), sides[c].end(), sides[c].begin(), tolower); - assert(code[0] == 'K'); + string fenStr = "8/" + sides[0] + char(8 - sides[0].length() + '0') + "/8/8/8/8/" + + sides[1] + char(8 - sides[1].length() + '0') + "/8 w - - 0 10"; - string sides[] = { code.substr(code.find('K', 1)), // Weak - code.substr(0, std::min(code.find('v'), code.find('K', 1))) }; // Strong + return set(fenStr, false, si, nullptr); + } - assert(sides[0].length() > 0 && sides[0].length() < 8); - assert(sides[1].length() > 0 && sides[1].length() < 8); - std::transform(sides[c].begin(), sides[c].end(), sides[c].begin(), tolower); + /// Position::fen() returns a FEN representation of the position. In case of + /// Chess960 the Shredder-FEN notation is used. This is mainly a debugging function. - string fenStr = "8/" + sides[0] + char(8 - sides[0].length() + '0') + "/8/8/8/8/" - + sides[1] + char(8 - sides[1].length() + '0') + "/8 w - - 0 10"; + string Position::fen() const { - return set(fenStr, false, si, nullptr); -} + int emptyCnt; + std::ostringstream ss; + for (Rank r = RANK_8; r >= RANK_1; --r) { + for (File f = FILE_A; f <= FILE_H; ++f) { + for (emptyCnt = 0; f <= FILE_H && empty(make_square(f, r)); ++f) ++emptyCnt; -/// Position::fen() returns a FEN representation of the position. In case of -/// Chess960 the Shredder-FEN notation is used. This is mainly a debugging function. + if (emptyCnt) ss << emptyCnt; -string Position::fen() const { + if (f <= FILE_H) ss << PieceToChar[piece_on(make_square(f, r))]; + } - int emptyCnt; - std::ostringstream ss; + if (r > RANK_1) ss << '/'; + } - for (Rank r = RANK_8; r >= RANK_1; --r) - { - for (File f = FILE_A; f <= FILE_H; ++f) - { - for (emptyCnt = 0; f <= FILE_H && empty(make_square(f, r)); ++f) - ++emptyCnt; + ss << (sideToMove == WHITE ? " w " : " b "); - if (emptyCnt) - ss << emptyCnt; + if (can_castle(WHITE_OO)) + ss << (chess960 ? char('A' + file_of(castling_rook_square(WHITE_OO))) : 'K'); - if (f <= FILE_H) - ss << PieceToChar[piece_on(make_square(f, r))]; - } + if (can_castle(WHITE_OOO)) + ss << (chess960 ? char('A' + file_of(castling_rook_square(WHITE_OOO))) : 'Q'); - if (r > RANK_1) - ss << '/'; - } + if (can_castle(BLACK_OO)) + ss << (chess960 ? char('a' + file_of(castling_rook_square(BLACK_OO))) : 'k'); - ss << (sideToMove == WHITE ? " w " : " b "); + if (can_castle(BLACK_OOO)) + ss << (chess960 ? char('a' + file_of(castling_rook_square(BLACK_OOO))) : 'q'); - if (can_castle(WHITE_OO)) - ss << (chess960 ? char('A' + file_of(castling_rook_square(WHITE_OO ))) : 'K'); + if (!can_castle(ANY_CASTLING)) ss << '-'; - if (can_castle(WHITE_OOO)) - ss << (chess960 ? char('A' + file_of(castling_rook_square(WHITE_OOO))) : 'Q'); + ss << (ep_square() == SQ_NONE ? " - " : " " + UCI::square(ep_square()) + " ") << st->rule50 + << " " << 1 + (gamePly - (sideToMove == BLACK)) / 2; - if (can_castle(BLACK_OO)) - ss << (chess960 ? char('a' + file_of(castling_rook_square(BLACK_OO ))) : 'k'); + return ss.str(); + } - if (can_castle(BLACK_OOO)) - ss << (chess960 ? char('a' + file_of(castling_rook_square(BLACK_OOO))) : 'q'); + /// update_slider_blockers() calculates st->blockersForKing[c] and st->pinners[~c], + /// which store respectively the pieces preventing king of color c from being in check + /// and the slider pieces of color ~c pinning pieces of color c to the king. + void Position::update_slider_blockers(Color c) const { - if (!can_castle(ANY_CASTLING)) - ss << '-'; + Square ksq = square(c); - ss << (ep_square() == SQ_NONE ? " - " : " " + UCI::square(ep_square()) + " ") - << st->rule50 << " " << 1 + (gamePly - (sideToMove == BLACK)) / 2; + st->blockersForKing[c] = 0; + st->pinners[~c] = 0; - return ss.str(); -} + // Snipers are sliders that attack 's' when a piece and other snipers are removed + Bitboard snipers = ((attacks_bb(ksq) & pieces(QUEEN, ROOK)) | + (attacks_bb(ksq) & pieces(QUEEN, BISHOP))) & + pieces(~c); + Bitboard occupancy = pieces() ^ snipers; -/// update_slider_blockers() calculates st->blockersForKing[c] and st->pinners[~c], -/// which store respectively the pieces preventing king of color c from being in check -/// and the slider pieces of color ~c pinning pieces of color c to the king. -void Position::update_slider_blockers(Color c) const { + while (snipers) { + Square sniperSq = pop_lsb(snipers); + Bitboard b = between_bb(ksq, sniperSq) & occupancy; - Square ksq = square(c); + if (b && !more_than_one(b)) { + st->blockersForKing[c] |= b; + if (b & pieces(c)) st->pinners[~c] |= sniperSq; + } + } + } - st->blockersForKing[c] = 0; - st->pinners[~c] = 0; - // Snipers are sliders that attack 's' when a piece and other snipers are removed - Bitboard snipers = ( (attacks_bb< ROOK>(ksq) & pieces(QUEEN, ROOK)) - | (attacks_bb(ksq) & pieces(QUEEN, BISHOP))) & pieces(~c); - Bitboard occupancy = pieces() ^ snipers; + /// Position::attackers_to() computes a bitboard of all pieces which attack a + /// given square. Slider attacks use the occupied bitboard to indicate occupancy. - while (snipers) - { - Square sniperSq = pop_lsb(snipers); - Bitboard b = between_bb(ksq, sniperSq) & occupancy; + Bitboard Position::attackers_to(Square s, Bitboard occupied) const { - if (b && !more_than_one(b)) - { - st->blockersForKing[c] |= b; - if (b & pieces(c)) - st->pinners[~c] |= sniperSq; + return (pawn_attacks_bb(BLACK, s) & pieces(WHITE, PAWN)) | + (pawn_attacks_bb(WHITE, s) & pieces(BLACK, PAWN)) | + (attacks_bb(s) & pieces(KNIGHT)) | + (attacks_bb(s, occupied) & pieces(ROOK, QUEEN)) | + (attacks_bb(s, occupied) & pieces(BISHOP, QUEEN)) | + (attacks_bb(s) & pieces(KING)); } - } -} -/// Position::attackers_to() computes a bitboard of all pieces which attack a -/// given square. Slider attacks use the occupied bitboard to indicate occupancy. + /// Position::legal() tests whether a pseudo-legal move is legal + + bool Position::legal(Move m) const { + + assert(is_ok(m)); + + Color us = sideToMove; + Square from = from_sq(m); + Square to = to_sq(m); -Bitboard Position::attackers_to(Square s, Bitboard occupied) const { + assert(color_of(moved_piece(m)) == us); + assert(piece_on(square(us)) == make_piece(us, KING)); - return (pawn_attacks_bb(BLACK, s) & pieces(WHITE, PAWN)) - | (pawn_attacks_bb(WHITE, s) & pieces(BLACK, PAWN)) - | (attacks_bb(s) & pieces(KNIGHT)) - | (attacks_bb< ROOK>(s, occupied) & pieces( ROOK, QUEEN)) - | (attacks_bb(s, occupied) & pieces(BISHOP, QUEEN)) - | (attacks_bb(s) & pieces(KING)); -} - - -/// Position::legal() tests whether a pseudo-legal move is legal - -bool Position::legal(Move m) const { - - assert(is_ok(m)); - - Color us = sideToMove; - Square from = from_sq(m); - Square to = to_sq(m); - - assert(color_of(moved_piece(m)) == us); - assert(piece_on(square(us)) == make_piece(us, KING)); - - // En passant captures are a tricky special case. Because they are rather - // uncommon, we do it simply by testing whether the king is attacked after - // the move is made. - if (type_of(m) == EN_PASSANT) - { - Square ksq = square(us); - Square capsq = to - pawn_push(us); - Bitboard occupied = (pieces() ^ from ^ capsq) | to; - - assert(to == ep_square()); - assert(moved_piece(m) == make_piece(us, PAWN)); - assert(piece_on(capsq) == make_piece(~us, PAWN)); - assert(piece_on(to) == NO_PIECE); - - return !(attacks_bb< ROOK>(ksq, occupied) & pieces(~us, QUEEN, ROOK)) - && !(attacks_bb(ksq, occupied) & pieces(~us, QUEEN, BISHOP)); - } - - // Castling moves generation does not check if the castling path is clear of - // enemy attacks, it is delayed at a later time: now! - if (type_of(m) == CASTLING) - { - // After castling, the rook and king final positions are the same in - // Chess960 as they would be in standard chess. - to = relative_square(us, to > from ? SQ_G1 : SQ_C1); - Direction step = to > from ? WEST : EAST; - - for (Square s = to; s != from; s += step) - if (attackers_to(s) & pieces(~us)) - return false; - - // In case of Chess960, verify if the Rook blocks some checks - // For instance an enemy queen in SQ_A1 when castling rook is in SQ_B1. - return !chess960 || !(blockers_for_king(us) & to_sq(m)); - } - - // If the moving piece is a king, check whether the destination square is - // attacked by the opponent. - if (type_of(piece_on(from)) == KING) - return !(attackers_to(to, pieces() ^ from) & pieces(~us)); - - // A non-king move is legal if and only if it is not pinned or it - // is moving along the ray towards or away from the king. - return !(blockers_for_king(us) & from) - || aligned(from, to, square(us)); -} - - -/// Position::pseudo_legal() takes a random move and tests whether the move is -/// pseudo legal. It is used to validate moves from TT that can be corrupted -/// due to SMP concurrent access or hash position key aliasing. - -bool Position::pseudo_legal(const Move m) const { - - Color us = sideToMove; - Square from = from_sq(m); - Square to = to_sq(m); - Piece pc = moved_piece(m); - - // Use a slower but simpler function for uncommon cases - // yet we skip the legality check of MoveList(). - if (type_of(m) != NORMAL) - return checkers() ? MoveList< EVASIONS>(*this).contains(m) - : MoveList(*this).contains(m); - - // Is not a promotion, so promotion piece must be empty - assert(promotion_type(m) - KNIGHT == NO_PIECE_TYPE); - - // If the 'from' square is not occupied by a piece belonging to the side to - // move, the move is obviously not legal. - if (pc == NO_PIECE || color_of(pc) != us) - return false; - - // The destination square cannot be occupied by a friendly piece - if (pieces(us) & to) - return false; - - // Handle the special case of a pawn move - if (type_of(pc) == PAWN) - { - // We have already handled promotion moves, so destination - // cannot be on the 8th/1st rank. - if ((Rank8BB | Rank1BB) & to) - return false; - - if ( !(pawn_attacks_bb(us, from) & pieces(~us) & to) // Not a capture - && !((from + pawn_push(us) == to) && empty(to)) // Not a single push - && !( (from + 2 * pawn_push(us) == to) // Not a double push - && (relative_rank(us, from) == RANK_2) - && empty(to) - && empty(to - pawn_push(us)))) - return false; - } - else if (!(attacks_bb(type_of(pc), from, pieces()) & to)) - return false; - - // Evasions generator already takes care to avoid some kind of illegal moves - // and legal() relies on this. We therefore have to take care that the same - // kind of moves are filtered out here. - if (checkers()) - { - if (type_of(pc) != KING) - { - // Double check? In this case a king move is required - if (more_than_one(checkers())) - return false; - - // Our move must be a blocking interposition or a capture of the checking piece - if (!(between_bb(square(us), lsb(checkers())) & to)) - return false; - } - // In case of king moves under check we have to remove king so as to catch - // invalid moves like b1a1 when opposite queen is on c1. - else if (attackers_to(to, pieces() ^ from) & pieces(~us)) - return false; - } - - return true; -} - - -/// Position::gives_check() tests whether a pseudo-legal move gives a check - -bool Position::gives_check(Move m) const { - - assert(is_ok(m)); - assert(color_of(moved_piece(m)) == sideToMove); - - Square from = from_sq(m); - Square to = to_sq(m); - - // Is there a direct check? - if (check_squares(type_of(piece_on(from))) & to) - return true; - - // Is there a discovered check? - if (blockers_for_king(~sideToMove) & from) - return !aligned(from, to, square(~sideToMove)) - || type_of(m) == CASTLING; - - switch (type_of(m)) - { - case NORMAL: - return false; - - case PROMOTION: - return attacks_bb(promotion_type(m), to, pieces() ^ from) & square(~sideToMove); - - // En passant capture with check? We have already handled the case - // of direct checks and ordinary discovered check, so the only case we - // need to handle is the unusual case of a discovered check through - // the captured pawn. - case EN_PASSANT: - { - Square capsq = make_square(file_of(to), rank_of(from)); - Bitboard b = (pieces() ^ from ^ capsq) | to; - - return (attacks_bb< ROOK>(square(~sideToMove), b) & pieces(sideToMove, QUEEN, ROOK)) - | (attacks_bb(square(~sideToMove), b) & pieces(sideToMove, QUEEN, BISHOP)); - } - default: //CASTLING - { - // Castling is encoded as 'king captures the rook' - Square rto = relative_square(sideToMove, to > from ? SQ_F1 : SQ_D1); - - return check_squares(ROOK) & rto; - } - } -} - - -/// Position::do_move() makes a move, and saves all information necessary -/// to a StateInfo object. The move is assumed to be legal. Pseudo-legal -/// moves should be filtered out before this function is called. - -void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { - - assert(is_ok(m)); - assert(&newSt != st); - - thisThread->nodes.fetch_add(1, std::memory_order_relaxed); - Key k = st->key ^ Zobrist::side; - - // Copy some fields of the old state to our new StateInfo object except the - // ones which are going to be recalculated from scratch anyway and then switch - // our state pointer to point to the new (ready to be updated) state. - std::memcpy(&newSt, st, offsetof(StateInfo, key)); - newSt.previous = st; - st = &newSt; - - // Increment ply counters. In particular, rule50 will be reset to zero later on - // in case of a capture or a pawn move. - ++gamePly; - ++st->rule50; - ++st->pliesFromNull; - - // Used by NNUE - st->accumulator.computed[WHITE] = false; - st->accumulator.computed[BLACK] = false; - auto& dp = st->dirtyPiece; - dp.dirty_num = 1; - - Color us = sideToMove; - Color them = ~us; - Square from = from_sq(m); - Square to = to_sq(m); - Piece pc = piece_on(from); - Piece captured = type_of(m) == EN_PASSANT ? make_piece(them, PAWN) : piece_on(to); - - assert(color_of(pc) == us); - assert(captured == NO_PIECE || color_of(captured) == (type_of(m) != CASTLING ? them : us)); - assert(type_of(captured) != KING); - - if (type_of(m) == CASTLING) - { - assert(pc == make_piece(us, KING)); - assert(captured == make_piece(us, ROOK)); - - Square rfrom, rto; - do_castling(us, from, to, rfrom, rto); - - k ^= Zobrist::psq[captured][rfrom] ^ Zobrist::psq[captured][rto]; - captured = NO_PIECE; - } - - if (captured) - { - Square capsq = to; - - // If the captured piece is a pawn, update pawn hash key, otherwise - // update non-pawn material. - if (type_of(captured) == PAWN) - { - if (type_of(m) == EN_PASSANT) - { - capsq -= pawn_push(us); - - assert(pc == make_piece(us, PAWN)); - assert(to == st->epSquare); - assert(relative_rank(us, to) == RANK_6); - assert(piece_on(to) == NO_PIECE); - assert(piece_on(capsq) == make_piece(them, PAWN)); - } - } - else - st->nonPawnMaterial[them] -= PieceValue[captured]; - - dp.dirty_num = 2; // 1 piece moved, 1 piece captured - dp.piece[1] = captured; - dp.from[1] = capsq; - dp.to[1] = SQ_NONE; - - // Update board and piece lists - remove_piece(capsq); - - // Update material hash key and prefetch access to materialTable - k ^= Zobrist::psq[captured][capsq]; - st->materialKey ^= Zobrist::psq[captured][pieceCount[captured]]; - - // Reset rule 50 counter - st->rule50 = 0; - } - - // Update hash key - k ^= Zobrist::psq[pc][from] ^ Zobrist::psq[pc][to]; - - // Reset en passant square - if (st->epSquare != SQ_NONE) - { - k ^= Zobrist::enpassant[file_of(st->epSquare)]; - st->epSquare = SQ_NONE; - } - - // Update castling rights if needed - if (st->castlingRights && (castlingRightsMask[from] | castlingRightsMask[to])) - { - k ^= Zobrist::castling[st->castlingRights]; - st->castlingRights &= ~(castlingRightsMask[from] | castlingRightsMask[to]); - k ^= Zobrist::castling[st->castlingRights]; - } - - // Move the piece. The tricky Chess960 castling is handled earlier - if (type_of(m) != CASTLING) - { - dp.piece[0] = pc; - dp.from[0] = from; - dp.to[0] = to; - - move_piece(from, to); - } - - // If the moving piece is a pawn do some special extra work - if (type_of(pc) == PAWN) - { - // Set en passant square if the moved pawn can be captured - if ( (int(to) ^ int(from)) == 16 - && (pawn_attacks_bb(us, to - pawn_push(us)) & pieces(them, PAWN))) - { - st->epSquare = to - pawn_push(us); - k ^= Zobrist::enpassant[file_of(st->epSquare)]; - } - - else if (type_of(m) == PROMOTION) - { - Piece promotion = make_piece(us, promotion_type(m)); - - assert(relative_rank(us, to) == RANK_8); - assert(type_of(promotion) >= KNIGHT && type_of(promotion) <= QUEEN); - - remove_piece(to); - put_piece(promotion, to); - - // Promoting pawn to SQ_NONE, promoted piece from SQ_NONE - dp.to[0] = SQ_NONE; - dp.piece[dp.dirty_num] = promotion; - dp.from[dp.dirty_num] = SQ_NONE; - dp.to[dp.dirty_num] = to; - dp.dirty_num++; - - // Update hash keys - k ^= Zobrist::psq[pc][to] ^ Zobrist::psq[promotion][to]; - st->materialKey ^= Zobrist::psq[promotion][pieceCount[promotion]-1] - ^ Zobrist::psq[pc][pieceCount[pc]]; - - // Update material - st->nonPawnMaterial[us] += PieceValue[promotion]; - } - - // Reset rule 50 draw counter - st->rule50 = 0; - } - - // Set capture piece - st->capturedPiece = captured; - - // Update the key with the final value - st->key = k; - - // Calculate checkers bitboard (if move gives check) - st->checkersBB = givesCheck ? attackers_to(square(them)) & pieces(us) : 0; - - sideToMove = ~sideToMove; - - // Update king attacks used for fast check detection - set_check_info(); - - // Calculate the repetition info. It is the ply distance from the previous - // occurrence of the same position, negative in the 3-fold case, or zero - // if the position was not repeated. - st->repetition = 0; - int end = std::min(st->rule50, st->pliesFromNull); - if (end >= 4) - { - StateInfo* stp = st->previous->previous; - for (int i = 4; i <= end; i += 2) - { - stp = stp->previous->previous; - if (stp->key == st->key) - { - st->repetition = stp->repetition ? -i : i; - break; - } - } - } - - assert(pos_is_ok()); -} - - -/// Position::undo_move() unmakes a move. When it returns, the position should -/// be restored to exactly the same state as before the move was made. - -void Position::undo_move(Move m) { - - assert(is_ok(m)); - - sideToMove = ~sideToMove; - - Color us = sideToMove; - Square from = from_sq(m); - Square to = to_sq(m); - Piece pc = piece_on(to); - - assert(empty(from) || type_of(m) == CASTLING); - assert(type_of(st->capturedPiece) != KING); - - if (type_of(m) == PROMOTION) - { - assert(relative_rank(us, to) == RANK_8); - assert(type_of(pc) == promotion_type(m)); - assert(type_of(pc) >= KNIGHT && type_of(pc) <= QUEEN); - - remove_piece(to); - pc = make_piece(us, PAWN); - put_piece(pc, to); - } - - if (type_of(m) == CASTLING) - { - Square rfrom, rto; - do_castling(us, from, to, rfrom, rto); - } - else - { - move_piece(to, from); // Put the piece back at the source square + // En passant captures are a tricky special case. Because they are rather + // uncommon, we do it simply by testing whether the king is attacked after + // the move is made. + if (type_of(m) == EN_PASSANT) { + Square ksq = square(us); + Square capsq = to - pawn_push(us); + Bitboard occupied = (pieces() ^ from ^ capsq) | to; - if (st->capturedPiece) - { - Square capsq = to; + assert(to == ep_square()); + assert(moved_piece(m) == make_piece(us, PAWN)); + assert(piece_on(capsq) == make_piece(~us, PAWN)); + assert(piece_on(to) == NO_PIECE); - if (type_of(m) == EN_PASSANT) - { - capsq -= pawn_push(us); + return !(attacks_bb(ksq, occupied) & pieces(~us, QUEEN, ROOK)) && + !(attacks_bb(ksq, occupied) & pieces(~us, QUEEN, BISHOP)); + } - assert(type_of(pc) == PAWN); - assert(to == st->previous->epSquare); - assert(relative_rank(us, to) == RANK_6); - assert(piece_on(capsq) == NO_PIECE); - assert(st->capturedPiece == make_piece(~us, PAWN)); - } + // Castling moves generation does not check if the castling path is clear of + // enemy attacks, it is delayed at a later time: now! + if (type_of(m) == CASTLING) { + // After castling, the rook and king final positions are the same in + // Chess960 as they would be in standard chess. + to = relative_square(us, to > from ? SQ_G1 : SQ_C1); + Direction step = to > from ? WEST : EAST; - put_piece(st->capturedPiece, capsq); // Restore the captured piece - } - } + for (Square s = to; s != from; s += step) + if (attackers_to(s) & pieces(~us)) return false; - // Finally point our state pointer back to the previous state - st = st->previous; - --gamePly; + // In case of Chess960, verify if the Rook blocks some checks + // For instance an enemy queen in SQ_A1 when castling rook is in SQ_B1. + return !chess960 || !(blockers_for_king(us) & to_sq(m)); + } - assert(pos_is_ok()); -} + // If the moving piece is a king, check whether the destination square is + // attacked by the opponent. + if (type_of(piece_on(from)) == KING) + return !(attackers_to(to, pieces() ^ from) & pieces(~us)); + // A non-king move is legal if and only if it is not pinned or it + // is moving along the ray towards or away from the king. + return !(blockers_for_king(us) & from) || aligned(from, to, square(us)); + } -/// Position::do_castling() is a helper used to do/undo a castling move. This -/// is a bit tricky in Chess960 where from/to squares can overlap. -template -void Position::do_castling(Color us, Square from, Square& to, Square& rfrom, Square& rto) { - bool kingSide = to > from; - rfrom = to; // Castling is encoded as "king captures friendly rook" - rto = relative_square(us, kingSide ? SQ_F1 : SQ_D1); - to = relative_square(us, kingSide ? SQ_G1 : SQ_C1); + /// Position::pseudo_legal() takes a random move and tests whether the move is + /// pseudo legal. It is used to validate moves from TT that can be corrupted + /// due to SMP concurrent access or hash position key aliasing. + + bool Position::pseudo_legal(const Move m) const { + + Color us = sideToMove; + Square from = from_sq(m); + Square to = to_sq(m); + Piece pc = moved_piece(m); + + // Use a slower but simpler function for uncommon cases + // yet we skip the legality check of MoveList(). + if (type_of(m) != NORMAL) + return checkers() ? MoveList(*this).contains(m) : + MoveList(*this).contains(m); + + // Is not a promotion, so promotion piece must be empty + assert(promotion_type(m) - KNIGHT == NO_PIECE_TYPE); + + // If the 'from' square is not occupied by a piece belonging to the side to + // move, the move is obviously not legal. + if (pc == NO_PIECE || color_of(pc) != us) return false; + + // The destination square cannot be occupied by a friendly piece + if (pieces(us) & to) return false; + + // Handle the special case of a pawn move + if (type_of(pc) == PAWN) { + // We have already handled promotion moves, so destination + // cannot be on the 8th/1st rank. + if ((Rank8BB | Rank1BB) & to) return false; + + if (!(pawn_attacks_bb(us, from) & pieces(~us) & to) // Not a capture + && !((from + pawn_push(us) == to) && empty(to)) // Not a single push + && + !((from + 2 * pawn_push(us) == to) // Not a double push + && (relative_rank(us, from) == RANK_2) && empty(to) && empty(to - pawn_push(us)))) + return false; + } else if (!(attacks_bb(type_of(pc), from, pieces()) & to)) + return false; + + // Evasions generator already takes care to avoid some kind of illegal moves + // and legal() relies on this. We therefore have to take care that the same + // kind of moves are filtered out here. + if (checkers()) { + if (type_of(pc) != KING) { + // Double check? In this case a king move is required + if (more_than_one(checkers())) return false; + + // Our move must be a blocking interposition or a capture of the checking piece + if (!(between_bb(square(us), lsb(checkers())) & to)) return false; + } + // In case of king moves under check we have to remove king so as to catch + // invalid moves like b1a1 when opposite queen is on c1. + else if (attackers_to(to, pieces() ^ from) & pieces(~us)) + return false; + } + + return true; + } - if (Do) - { - auto& dp = st->dirtyPiece; - dp.piece[0] = make_piece(us, KING); - dp.from[0] = from; - dp.to[0] = to; - dp.piece[1] = make_piece(us, ROOK); - dp.from[1] = rfrom; - dp.to[1] = rto; - dp.dirty_num = 2; - } - // Remove both pieces first since squares could overlap in Chess960 - remove_piece(Do ? from : to); - remove_piece(Do ? rfrom : rto); - board[Do ? from : to] = board[Do ? rfrom : rto] = NO_PIECE; // Since remove_piece doesn't do this for us - put_piece(make_piece(us, KING), Do ? to : from); - put_piece(make_piece(us, ROOK), Do ? rto : rfrom); -} + /// Position::gives_check() tests whether a pseudo-legal move gives a check + bool Position::gives_check(Move m) const { -/// Position::do_null_move() is used to do a "null move": it flips -/// the side to move without executing any move on the board. + assert(is_ok(m)); + assert(color_of(moved_piece(m)) == sideToMove); -void Position::do_null_move(StateInfo& newSt) { + Square from = from_sq(m); + Square to = to_sq(m); - assert(!checkers()); - assert(&newSt != st); + // Is there a direct check? + if (check_squares(type_of(piece_on(from))) & to) return true; - std::memcpy(&newSt, st, offsetof(StateInfo, accumulator)); + // Is there a discovered check? + if (blockers_for_king(~sideToMove) & from) + return !aligned(from, to, square(~sideToMove)) || type_of(m) == CASTLING; - newSt.previous = st; - st = &newSt; + switch (type_of(m)) { + case NORMAL : return false; - st->dirtyPiece.dirty_num = 0; - st->dirtyPiece.piece[0] = NO_PIECE; // Avoid checks in UpdateAccumulator() - st->accumulator.computed[WHITE] = false; - st->accumulator.computed[BLACK] = false; + case PROMOTION : + return attacks_bb(promotion_type(m), to, pieces() ^ from) & square(~sideToMove); - if (st->epSquare != SQ_NONE) - { - st->key ^= Zobrist::enpassant[file_of(st->epSquare)]; - st->epSquare = SQ_NONE; - } + // En passant capture with check? We have already handled the case + // of direct checks and ordinary discovered check, so the only case we + // need to handle is the unusual case of a discovered check through + // the captured pawn. + case EN_PASSANT : { + Square capsq = make_square(file_of(to), rank_of(from)); + Bitboard b = (pieces() ^ from ^ capsq) | to; - st->key ^= Zobrist::side; - ++st->rule50; - prefetch(TT.first_entry(key())); + return (attacks_bb(square(~sideToMove), b) & + pieces(sideToMove, QUEEN, ROOK)) | + (attacks_bb(square(~sideToMove), b) & + pieces(sideToMove, QUEEN, BISHOP)); + } + default : //CASTLING + { + // Castling is encoded as 'king captures the rook' + Square rto = relative_square(sideToMove, to > from ? SQ_F1 : SQ_D1); - st->pliesFromNull = 0; + return check_squares(ROOK) & rto; + } + } + } - sideToMove = ~sideToMove; - set_check_info(); + /// Position::do_move() makes a move, and saves all information necessary + /// to a StateInfo object. The move is assumed to be legal. Pseudo-legal + /// moves should be filtered out before this function is called. + + void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { + + assert(is_ok(m)); + assert(&newSt != st); + + thisThread->nodes.fetch_add(1, std::memory_order_relaxed); + Key k = st->key ^ Zobrist::side; + + // Copy some fields of the old state to our new StateInfo object except the + // ones which are going to be recalculated from scratch anyway and then switch + // our state pointer to point to the new (ready to be updated) state. + std::memcpy(&newSt, st, offsetof(StateInfo, key)); + newSt.previous = st; + st = &newSt; + + // Increment ply counters. In particular, rule50 will be reset to zero later on + // in case of a capture or a pawn move. + ++gamePly; + ++st->rule50; + ++st->pliesFromNull; + + // Used by NNUE + st->accumulator.computed[WHITE] = false; + st->accumulator.computed[BLACK] = false; + auto& dp = st->dirtyPiece; + dp.dirty_num = 1; + + Color us = sideToMove; + Color them = ~us; + Square from = from_sq(m); + Square to = to_sq(m); + Piece pc = piece_on(from); + Piece captured = type_of(m) == EN_PASSANT ? make_piece(them, PAWN) : piece_on(to); + + assert(color_of(pc) == us); + assert(captured == NO_PIECE || color_of(captured) == (type_of(m) != CASTLING ? them : us)); + assert(type_of(captured) != KING); + + if (type_of(m) == CASTLING) { + assert(pc == make_piece(us, KING)); + assert(captured == make_piece(us, ROOK)); + + Square rfrom, rto; + do_castling(us, from, to, rfrom, rto); + + k ^= Zobrist::psq[captured][rfrom] ^ Zobrist::psq[captured][rto]; + captured = NO_PIECE; + } + + if (captured) { + Square capsq = to; + + // If the captured piece is a pawn, update pawn hash key, otherwise + // update non-pawn material. + if (type_of(captured) == PAWN) { + if (type_of(m) == EN_PASSANT) { + capsq -= pawn_push(us); + + assert(pc == make_piece(us, PAWN)); + assert(to == st->epSquare); + assert(relative_rank(us, to) == RANK_6); + assert(piece_on(to) == NO_PIECE); + assert(piece_on(capsq) == make_piece(them, PAWN)); + } + } else + st->nonPawnMaterial[them] -= PieceValue[captured]; + + dp.dirty_num = 2; // 1 piece moved, 1 piece captured + dp.piece[1] = captured; + dp.from[1] = capsq; + dp.to[1] = SQ_NONE; + + // Update board and piece lists + remove_piece(capsq); + + // Update material hash key and prefetch access to materialTable + k ^= Zobrist::psq[captured][capsq]; + st->materialKey ^= Zobrist::psq[captured][pieceCount[captured]]; + + // Reset rule 50 counter + st->rule50 = 0; + } + + // Update hash key + k ^= Zobrist::psq[pc][from] ^ Zobrist::psq[pc][to]; + + // Reset en passant square + if (st->epSquare != SQ_NONE) { + k ^= Zobrist::enpassant[file_of(st->epSquare)]; + st->epSquare = SQ_NONE; + } + + // Update castling rights if needed + if (st->castlingRights && (castlingRightsMask[from] | castlingRightsMask[to])) { + k ^= Zobrist::castling[st->castlingRights]; + st->castlingRights &= ~(castlingRightsMask[from] | castlingRightsMask[to]); + k ^= Zobrist::castling[st->castlingRights]; + } + + // Move the piece. The tricky Chess960 castling is handled earlier + if (type_of(m) != CASTLING) { + dp.piece[0] = pc; + dp.from[0] = from; + dp.to[0] = to; + + move_piece(from, to); + } + + // If the moving piece is a pawn do some special extra work + if (type_of(pc) == PAWN) { + // Set en passant square if the moved pawn can be captured + if ((int(to) ^ int(from)) == 16 && + (pawn_attacks_bb(us, to - pawn_push(us)) & pieces(them, PAWN))) { + st->epSquare = to - pawn_push(us); + k ^= Zobrist::enpassant[file_of(st->epSquare)]; + } + + else if (type_of(m) == PROMOTION) { + Piece promotion = make_piece(us, promotion_type(m)); + + assert(relative_rank(us, to) == RANK_8); + assert(type_of(promotion) >= KNIGHT && type_of(promotion) <= QUEEN); + + remove_piece(to); + put_piece(promotion, to); + + // Promoting pawn to SQ_NONE, promoted piece from SQ_NONE + dp.to[0] = SQ_NONE; + dp.piece[dp.dirty_num] = promotion; + dp.from[dp.dirty_num] = SQ_NONE; + dp.to[dp.dirty_num] = to; + dp.dirty_num++; + + // Update hash keys + k ^= Zobrist::psq[pc][to] ^ Zobrist::psq[promotion][to]; + st->materialKey ^= Zobrist::psq[promotion][pieceCount[promotion] - 1] ^ + Zobrist::psq[pc][pieceCount[pc]]; + + // Update material + st->nonPawnMaterial[us] += PieceValue[promotion]; + } + + // Reset rule 50 draw counter + st->rule50 = 0; + } + + // Set capture piece + st->capturedPiece = captured; + + // Update the key with the final value + st->key = k; + + // Calculate checkers bitboard (if move gives check) + st->checkersBB = givesCheck ? attackers_to(square(them)) & pieces(us) : 0; + + sideToMove = ~sideToMove; + + // Update king attacks used for fast check detection + set_check_info(); + + // Calculate the repetition info. It is the ply distance from the previous + // occurrence of the same position, negative in the 3-fold case, or zero + // if the position was not repeated. + st->repetition = 0; + int end = std::min(st->rule50, st->pliesFromNull); + if (end >= 4) { + StateInfo* stp = st->previous->previous; + for (int i = 4; i <= end; i += 2) { + stp = stp->previous->previous; + if (stp->key == st->key) { + st->repetition = stp->repetition ? -i : i; + break; + } + } + } + + assert(pos_is_ok()); + } - st->repetition = 0; - assert(pos_is_ok()); -} + /// Position::undo_move() unmakes a move. When it returns, the position should + /// be restored to exactly the same state as before the move was made. + void Position::undo_move(Move m) { -/// Position::undo_null_move() must be used to undo a "null move" + assert(is_ok(m)); -void Position::undo_null_move() { + sideToMove = ~sideToMove; - assert(!checkers()); + Color us = sideToMove; + Square from = from_sq(m); + Square to = to_sq(m); + Piece pc = piece_on(to); - st = st->previous; - sideToMove = ~sideToMove; -} + assert(empty(from) || type_of(m) == CASTLING); + assert(type_of(st->capturedPiece) != KING); + if (type_of(m) == PROMOTION) { + assert(relative_rank(us, to) == RANK_8); + assert(type_of(pc) == promotion_type(m)); + assert(type_of(pc) >= KNIGHT && type_of(pc) <= QUEEN); -/// Position::key_after() computes the new hash key after the given move. Needed -/// for speculative prefetch. It doesn't recognize special moves like castling, -/// en passant and promotions. + remove_piece(to); + pc = make_piece(us, PAWN); + put_piece(pc, to); + } -Key Position::key_after(Move m) const { + if (type_of(m) == CASTLING) { + Square rfrom, rto; + do_castling(us, from, to, rfrom, rto); + } else { + move_piece(to, from); // Put the piece back at the source square - Square from = from_sq(m); - Square to = to_sq(m); - Piece pc = piece_on(from); - Piece captured = piece_on(to); - Key k = st->key ^ Zobrist::side; + if (st->capturedPiece) { + Square capsq = to; - if (captured) - k ^= Zobrist::psq[captured][to]; + if (type_of(m) == EN_PASSANT) { + capsq -= pawn_push(us); - k ^= Zobrist::psq[pc][to] ^ Zobrist::psq[pc][from]; + assert(type_of(pc) == PAWN); + assert(to == st->previous->epSquare); + assert(relative_rank(us, to) == RANK_6); + assert(piece_on(capsq) == NO_PIECE); + assert(st->capturedPiece == make_piece(~us, PAWN)); + } - return (captured || type_of(pc) == PAWN) - ? k : adjust_key50(k); -} + put_piece(st->capturedPiece, capsq); // Restore the captured piece + } + } + // Finally point our state pointer back to the previous state + st = st->previous; + --gamePly; -/// Position::see_ge (Static Exchange Evaluation Greater or Equal) tests if the -/// SEE value of move is greater or equal to the given threshold. We'll use an -/// algorithm similar to alpha-beta pruning with a null window. + assert(pos_is_ok()); + } -bool Position::see_ge(Move m, Bitboard& occupied, Value threshold) const { - assert(is_ok(m)); + /// Position::do_castling() is a helper used to do/undo a castling move. This + /// is a bit tricky in Chess960 where from/to squares can overlap. + template + void Position::do_castling(Color us, Square from, Square& to, Square& rfrom, Square& rto) { + + bool kingSide = to > from; + rfrom = to; // Castling is encoded as "king captures friendly rook" + rto = relative_square(us, kingSide ? SQ_F1 : SQ_D1); + to = relative_square(us, kingSide ? SQ_G1 : SQ_C1); + + if (Do) { + auto& dp = st->dirtyPiece; + dp.piece[0] = make_piece(us, KING); + dp.from[0] = from; + dp.to[0] = to; + dp.piece[1] = make_piece(us, ROOK); + dp.from[1] = rfrom; + dp.to[1] = rto; + dp.dirty_num = 2; + } + + // Remove both pieces first since squares could overlap in Chess960 + remove_piece(Do ? from : to); + remove_piece(Do ? rfrom : rto); + board[Do ? from : to] = board[Do ? rfrom : rto] = + NO_PIECE; // Since remove_piece doesn't do this for us + put_piece(make_piece(us, KING), Do ? to : from); + put_piece(make_piece(us, ROOK), Do ? rto : rfrom); + } + - // Only deal with normal moves, assume others pass a simple SEE - if (type_of(m) != NORMAL) - return VALUE_ZERO >= threshold; + /// Position::do_null_move() is used to do a "null move": it flips + /// the side to move without executing any move on the board. - Square from = from_sq(m), to = to_sq(m); + void Position::do_null_move(StateInfo& newSt) { - int swap = PieceValue[piece_on(to)] - threshold; - if (swap < 0) - return false; + assert(!checkers()); + assert(&newSt != st); - swap = PieceValue[piece_on(from)] - swap; - if (swap <= 0) - return true; + std::memcpy(&newSt, st, offsetof(StateInfo, accumulator)); - assert(color_of(piece_on(from)) == sideToMove); - occupied = pieces() ^ from ^ to; // xoring to is important for pinned piece logic - Color stm = sideToMove; - Bitboard attackers = attackers_to(to, occupied); - Bitboard stmAttackers, bb; - int res = 1; + newSt.previous = st; + st = &newSt; - while (true) - { - stm = ~stm; - attackers &= occupied; + st->dirtyPiece.dirty_num = 0; + st->dirtyPiece.piece[0] = NO_PIECE; // Avoid checks in UpdateAccumulator() + st->accumulator.computed[WHITE] = false; + st->accumulator.computed[BLACK] = false; - // If stm has no more attackers then give up: stm loses - if (!(stmAttackers = attackers & pieces(stm))) - break; + if (st->epSquare != SQ_NONE) { + st->key ^= Zobrist::enpassant[file_of(st->epSquare)]; + st->epSquare = SQ_NONE; + } - // Don't allow pinned pieces to attack as long as there are - // pinners on their original square. - if (pinners(~stm) & occupied) - { - stmAttackers &= ~blockers_for_king(stm); + st->key ^= Zobrist::side; + ++st->rule50; + prefetch(TT.first_entry(key())); - if (!stmAttackers) - break; - } + st->pliesFromNull = 0; - res ^= 1; + sideToMove = ~sideToMove; - // Locate and remove the next least valuable attacker, and add to - // the bitboard 'attackers' any X-ray attackers behind it. - if ((bb = stmAttackers & pieces(PAWN))) - { - occupied ^= least_significant_square_bb(bb); - if ((swap = PawnValue - swap) < res) - break; + set_check_info(); - attackers |= attacks_bb(to, occupied) & pieces(BISHOP, QUEEN); - } + st->repetition = 0; + + assert(pos_is_ok()); + } - else if ((bb = stmAttackers & pieces(KNIGHT))) - { - occupied ^= least_significant_square_bb(bb); - if ((swap = KnightValue - swap) < res) - break; - } - else if ((bb = stmAttackers & pieces(BISHOP))) - { - occupied ^= least_significant_square_bb(bb); - if ((swap = BishopValue - swap) < res) - break; + /// Position::undo_null_move() must be used to undo a "null move" - attackers |= attacks_bb(to, occupied) & pieces(BISHOP, QUEEN); - } + void Position::undo_null_move() { - else if ((bb = stmAttackers & pieces(ROOK))) - { - occupied ^= least_significant_square_bb(bb); - if ((swap = RookValue - swap) < res) - break; + assert(!checkers()); - attackers |= attacks_bb(to, occupied) & pieces(ROOK, QUEEN); - } + st = st->previous; + sideToMove = ~sideToMove; + } - else if ((bb = stmAttackers & pieces(QUEEN))) - { - occupied ^= least_significant_square_bb(bb); - if ((swap = QueenValue - swap) < res) - break; - attackers |= (attacks_bb(to, occupied) & pieces(BISHOP, QUEEN)) - | (attacks_bb(to, occupied) & pieces(ROOK , QUEEN)); - } + /// Position::key_after() computes the new hash key after the given move. Needed + /// for speculative prefetch. It doesn't recognize special moves like castling, + /// en passant and promotions. - else // KING - // If we "capture" with the king but opponent still has attackers, - // reverse the result. - return (attackers & ~pieces(stm)) ? res ^ 1 : res; - } + Key Position::key_after(Move m) const { - return bool(res); -} + Square from = from_sq(m); + Square to = to_sq(m); + Piece pc = piece_on(from); + Piece captured = piece_on(to); + Key k = st->key ^ Zobrist::side; -bool Position::see_ge(Move m, Value threshold) const { - Bitboard occupied; - return see_ge(m, occupied, threshold); -} - - -/// Position::is_draw() tests whether the position is drawn by 50-move rule -/// or by repetition. It does not detect stalemates. - -bool Position::is_draw(int ply) const { - - if (st->rule50 > 99 && (!checkers() || MoveList(*this).size())) - return true; - - // Return a draw score if a position repeats once earlier but strictly - // after the root, or repeats twice before or at the root. - return st->repetition && st->repetition < ply; -} - - -// Position::has_repeated() tests whether there has been at least one repetition -// of positions since the last capture or pawn move. + if (captured) k ^= Zobrist::psq[captured][to]; -bool Position::has_repeated() const { - - StateInfo* stc = st; - int end = std::min(st->rule50, st->pliesFromNull); - while (end-- >= 4) - { - if (stc->repetition) - return true; + k ^= Zobrist::psq[pc][to] ^ Zobrist::psq[pc][from]; - stc = stc->previous; + return (captured || type_of(pc) == PAWN) ? k : adjust_key50(k); } - return false; -} -/// Position::has_game_cycle() tests if the position has a move which draws by repetition, -/// or an earlier position has a move that directly reaches the current position. + /// Position::see_ge (Static Exchange Evaluation Greater or Equal) tests if the + /// SEE value of move is greater or equal to the given threshold. We'll use an + /// algorithm similar to alpha-beta pruning with a null window. -bool Position::has_game_cycle(int ply) const { + bool Position::see_ge(Move m, Bitboard& occupied, Value threshold) const { - int j; + assert(is_ok(m)); - int end = std::min(st->rule50, st->pliesFromNull); + // Only deal with normal moves, assume others pass a simple SEE + if (type_of(m) != NORMAL) return VALUE_ZERO >= threshold; - if (end < 3) - return false; + Square from = from_sq(m), to = to_sq(m); - Key originalKey = st->key; - StateInfo* stp = st->previous; + int swap = PieceValue[piece_on(to)] - threshold; + if (swap < 0) return false; - for (int i = 3; i <= end; i += 2) - { - stp = stp->previous->previous; + swap = PieceValue[piece_on(from)] - swap; + if (swap <= 0) return true; - Key moveKey = originalKey ^ stp->key; - if ( (j = H1(moveKey), cuckoo[j] == moveKey) - || (j = H2(moveKey), cuckoo[j] == moveKey)) - { - Move move = cuckooMove[j]; - Square s1 = from_sq(move); - Square s2 = to_sq(move); + assert(color_of(piece_on(from)) == sideToMove); + occupied = pieces() ^ from ^ to; // xoring to is important for pinned piece logic + Color stm = sideToMove; + Bitboard attackers = attackers_to(to, occupied); + Bitboard stmAttackers, bb; + int res = 1; - if (!((between_bb(s1, s2) ^ s2) & pieces())) - { - if (ply > i) - return true; + while (true) { + stm = ~stm; + attackers &= occupied; - // For nodes before or at the root, check that the move is a - // repetition rather than a move to the current position. - // In the cuckoo table, both moves Rc1c5 and Rc5c1 are stored in - // the same location, so we have to select which square to check. - if (color_of(piece_on(empty(s1) ? s2 : s1)) != side_to_move()) - continue; + // If stm has no more attackers then give up: stm loses + if (!(stmAttackers = attackers & pieces(stm))) break; - // For repetitions before or at the root, require one more - if (stp->repetition) - return true; - } - } - } - return false; -} + // Don't allow pinned pieces to attack as long as there are + // pinners on their original square. + if (pinners(~stm) & occupied) { + stmAttackers &= ~blockers_for_king(stm); + if (!stmAttackers) break; + } -/// Position::flip() flips position with the white and black sides reversed. This -/// is only useful for debugging e.g. for finding evaluation symmetry bugs. + res ^= 1; -void Position::flip() { + // Locate and remove the next least valuable attacker, and add to + // the bitboard 'attackers' any X-ray attackers behind it. + if ((bb = stmAttackers & pieces(PAWN))) { + occupied ^= least_significant_square_bb(bb); + if ((swap = PawnValue - swap) < res) break; - string f, token; - std::stringstream ss(fen()); + attackers |= attacks_bb(to, occupied) & pieces(BISHOP, QUEEN); + } - for (Rank r = RANK_8; r >= RANK_1; --r) // Piece placement - { - std::getline(ss, token, r > RANK_1 ? '/' : ' '); - f.insert(0, token + (f.empty() ? " " : "/")); - } + else if ((bb = stmAttackers & pieces(KNIGHT))) { + occupied ^= least_significant_square_bb(bb); + if ((swap = KnightValue - swap) < res) break; + } - ss >> token; // Active color - f += (token == "w" ? "B " : "W "); // Will be lowercased later + else if ((bb = stmAttackers & pieces(BISHOP))) { + occupied ^= least_significant_square_bb(bb); + if ((swap = BishopValue - swap) < res) break; - ss >> token; // Castling availability - f += token + " "; + attackers |= attacks_bb(to, occupied) & pieces(BISHOP, QUEEN); + } - std::transform(f.begin(), f.end(), f.begin(), - [](char c) { return char(islower(c) ? toupper(c) : tolower(c)); }); + else if ((bb = stmAttackers & pieces(ROOK))) { + occupied ^= least_significant_square_bb(bb); + if ((swap = RookValue - swap) < res) break; - ss >> token; // En passant square - f += (token == "-" ? token : token.replace(1, 1, token[1] == '3' ? "6" : "3")); + attackers |= attacks_bb(to, occupied) & pieces(ROOK, QUEEN); + } - std::getline(ss, token); // Half and full moves - f += token; + else if ((bb = stmAttackers & pieces(QUEEN))) { + occupied ^= least_significant_square_bb(bb); + if ((swap = QueenValue - swap) < res) break; - set(f, is_chess960(), st, this_thread()); + attackers |= (attacks_bb(to, occupied) & pieces(BISHOP, QUEEN)) | + (attacks_bb(to, occupied) & pieces(ROOK, QUEEN)); + } - assert(pos_is_ok()); -} + else // KING + // If we "capture" with the king but opponent still has attackers, + // reverse the result. + return (attackers & ~pieces(stm)) ? res ^ 1 : res; + } + return bool(res); + } -/// Position::pos_is_ok() performs some consistency checks for the -/// position object and raises an asserts if something wrong is detected. -/// This is meant to be helpful when debugging. + bool Position::see_ge(Move m, Value threshold) const { + Bitboard occupied; + return see_ge(m, occupied, threshold); + } -bool Position::pos_is_ok() const { - constexpr bool Fast = true; // Quick (default) or full check? + /// Position::is_draw() tests whether the position is drawn by 50-move rule + /// or by repetition. It does not detect stalemates. - if ( (sideToMove != WHITE && sideToMove != BLACK) - || piece_on(square(WHITE)) != W_KING - || piece_on(square(BLACK)) != B_KING - || ( ep_square() != SQ_NONE - && relative_rank(sideToMove, ep_square()) != RANK_6)) - assert(0 && "pos_is_ok: Default"); + bool Position::is_draw(int ply) const { - if (Fast) - return true; + if (st->rule50 > 99 && (!checkers() || MoveList(*this).size())) return true; + + // Return a draw score if a position repeats once earlier but strictly + // after the root, or repeats twice before or at the root. + return st->repetition && st->repetition < ply; + } - if ( pieceCount[W_KING] != 1 - || pieceCount[B_KING] != 1 - || attackers_to(square(~sideToMove)) & pieces(sideToMove)) - assert(0 && "pos_is_ok: Kings"); - if ( (pieces(PAWN) & (Rank1BB | Rank8BB)) - || pieceCount[W_PAWN] > 8 - || pieceCount[B_PAWN] > 8) - assert(0 && "pos_is_ok: Pawns"); + // Position::has_repeated() tests whether there has been at least one repetition + // of positions since the last capture or pawn move. + + bool Position::has_repeated() const { + + StateInfo* stc = st; + int end = std::min(st->rule50, st->pliesFromNull); + while (end-- >= 4) { + if (stc->repetition) return true; + + stc = stc->previous; + } + return false; + } - if ( (pieces(WHITE) & pieces(BLACK)) - || (pieces(WHITE) | pieces(BLACK)) != pieces() - || popcount(pieces(WHITE)) > 16 - || popcount(pieces(BLACK)) > 16) - assert(0 && "pos_is_ok: Bitboards"); - for (PieceType p1 = PAWN; p1 <= KING; ++p1) - for (PieceType p2 = PAWN; p2 <= KING; ++p2) - if (p1 != p2 && (pieces(p1) & pieces(p2))) - assert(0 && "pos_is_ok: Bitboards"); + /// Position::has_game_cycle() tests if the position has a move which draws by repetition, + /// or an earlier position has a move that directly reaches the current position. + bool Position::has_game_cycle(int ply) const { - for (Piece pc : Pieces) - if ( pieceCount[pc] != popcount(pieces(color_of(pc), type_of(pc))) - || pieceCount[pc] != std::count(board, board + SQUARE_NB, pc)) - assert(0 && "pos_is_ok: Pieces"); + int j; - for (Color c : { WHITE, BLACK }) - for (CastlingRights cr : {c & KING_SIDE, c & QUEEN_SIDE}) - { - if (!can_castle(cr)) - continue; + int end = std::min(st->rule50, st->pliesFromNull); - if ( piece_on(castlingRookSquare[cr]) != make_piece(c, ROOK) - || castlingRightsMask[castlingRookSquare[cr]] != cr - || (castlingRightsMask[square(c)] & cr) != cr) - assert(0 && "pos_is_ok: Castling"); - } + if (end < 3) return false; - return true; -} + Key originalKey = st->key; + StateInfo* stp = st->previous; + + for (int i = 3; i <= end; i += 2) { + stp = stp->previous->previous; + + Key moveKey = originalKey ^ stp->key; + if ((j = H1(moveKey), cuckoo[j] == moveKey) || + (j = H2(moveKey), cuckoo[j] == moveKey)) { + Move move = cuckooMove[j]; + Square s1 = from_sq(move); + Square s2 = to_sq(move); + + if (!((between_bb(s1, s2) ^ s2) & pieces())) { + if (ply > i) return true; + + // For nodes before or at the root, check that the move is a + // repetition rather than a move to the current position. + // In the cuckoo table, both moves Rc1c5 and Rc5c1 are stored in + // the same location, so we have to select which square to check. + if (color_of(piece_on(empty(s1) ? s2 : s1)) != side_to_move()) continue; + + // For repetitions before or at the root, require one more + if (stp->repetition) return true; + } + } + } + return false; + } + + + /// Position::flip() flips position with the white and black sides reversed. This + /// is only useful for debugging e.g. for finding evaluation symmetry bugs. + + void Position::flip() { + + string f, token; + std::stringstream ss(fen()); + + for (Rank r = RANK_8; r >= RANK_1; --r) // Piece placement + { + std::getline(ss, token, r > RANK_1 ? '/' : ' '); + f.insert(0, token + (f.empty() ? " " : "/")); + } + + ss >> token; // Active color + f += (token == "w" ? "B " : "W "); // Will be lowercased later + + ss >> token; // Castling availability + f += token + " "; + + std::transform(f.begin(), f.end(), f.begin(), + [](char c) { return char(islower(c) ? toupper(c) : tolower(c)); }); + + ss >> token; // En passant square + f += (token == "-" ? token : token.replace(1, 1, token[1] == '3' ? "6" : "3")); + + std::getline(ss, token); // Half and full moves + f += token; + + set(f, is_chess960(), st, this_thread()); + + assert(pos_is_ok()); + } + + + /// Position::pos_is_ok() performs some consistency checks for the + /// position object and raises an asserts if something wrong is detected. + /// This is meant to be helpful when debugging. + + bool Position::pos_is_ok() const { + + constexpr bool Fast = true; // Quick (default) or full check? + + if ((sideToMove != WHITE && sideToMove != BLACK) || + piece_on(square(WHITE)) != W_KING || piece_on(square(BLACK)) != B_KING || + (ep_square() != SQ_NONE && relative_rank(sideToMove, ep_square()) != RANK_6)) + assert(0 && "pos_is_ok: Default"); + + if (Fast) return true; + + if (pieceCount[W_KING] != 1 || pieceCount[B_KING] != 1 || + attackers_to(square(~sideToMove)) & pieces(sideToMove)) + assert(0 && "pos_is_ok: Kings"); + + if ((pieces(PAWN) & (Rank1BB | Rank8BB)) || pieceCount[W_PAWN] > 8 || + pieceCount[B_PAWN] > 8) + assert(0 && "pos_is_ok: Pawns"); + + if ((pieces(WHITE) & pieces(BLACK)) || (pieces(WHITE) | pieces(BLACK)) != pieces() || + popcount(pieces(WHITE)) > 16 || popcount(pieces(BLACK)) > 16) + assert(0 && "pos_is_ok: Bitboards"); + + for (PieceType p1 = PAWN; p1 <= KING; ++p1) + for (PieceType p2 = PAWN; p2 <= KING; ++p2) + if (p1 != p2 && (pieces(p1) & pieces(p2))) assert(0 && "pos_is_ok: Bitboards"); + + + for (Piece pc : Pieces) + if (pieceCount[pc] != popcount(pieces(color_of(pc), type_of(pc))) || + pieceCount[pc] != std::count(board, board + SQUARE_NB, pc)) + assert(0 && "pos_is_ok: Pieces"); + + for (Color c : {WHITE, BLACK}) + for (CastlingRights cr : {c & KING_SIDE, c & QUEEN_SIDE}) { + if (!can_castle(cr)) continue; + + if (piece_on(castlingRookSquare[cr]) != make_piece(c, ROOK) || + castlingRightsMask[castlingRookSquare[cr]] != cr || + (castlingRightsMask[square(c)] & cr) != cr) + assert(0 && "pos_is_ok: Castling"); + } + + return true; + } -} // namespace Stockfish +} // namespace Stockfish diff --git a/src/position.h b/src/position.h index ca7c3ace811..106cf28cff1 100644 --- a/src/position.h +++ b/src/position.h @@ -31,374 +31,321 @@ namespace Stockfish { -/// StateInfo struct stores information needed to restore a Position object to -/// its previous state when we retract a move. Whenever a move is made on the -/// board (by calling Position::do_move), a StateInfo object must be passed. - -struct StateInfo { - - // Copied when making a move - Key materialKey; - Value nonPawnMaterial[COLOR_NB]; - int castlingRights; - int rule50; - int pliesFromNull; - Square epSquare; - - // Not copied when making a move (will be recomputed anyhow) - Key key; - Bitboard checkersBB; - StateInfo* previous; - Bitboard blockersForKing[COLOR_NB]; - Bitboard pinners[COLOR_NB]; - Bitboard checkSquares[PIECE_TYPE_NB]; - Piece capturedPiece; - int repetition; - - // Used by NNUE - Eval::NNUE::Accumulator accumulator; - DirtyPiece dirtyPiece; -}; - - -/// A list to keep track of the position states along the setup moves (from the -/// start position to the position just before the search starts). Needed by -/// 'draw by repetition' detection. Use a std::deque because pointers to -/// elements are not invalidated upon list resizing. -using StateListPtr = std::unique_ptr>; - - -/// Position class stores information regarding the board representation as -/// pieces, side to move, hash keys, castling info, etc. Important methods are -/// do_move() and undo_move(), used by the search to update node info when -/// traversing the search tree. -class Thread; - -class Position { -public: - static void init(); - - Position() = default; - Position(const Position&) = delete; - Position& operator=(const Position&) = delete; - - // FEN string input/output - Position& set(const std::string& fenStr, bool isChess960, StateInfo* si, Thread* th); - Position& set(const std::string& code, Color c, StateInfo* si); - std::string fen() const; - - // Position representation - Bitboard pieces(PieceType pt = ALL_PIECES) const; - template Bitboard pieces(PieceType pt, PieceTypes... pts) const; - Bitboard pieces(Color c) const; - template Bitboard pieces(Color c, PieceTypes... pts) const; - Piece piece_on(Square s) const; - Square ep_square() const; - bool empty(Square s) const; - template int count(Color c) const; - template int count() const; - template Square square(Color c) const; - - // Castling - CastlingRights castling_rights(Color c) const; - bool can_castle(CastlingRights cr) const; - bool castling_impeded(CastlingRights cr) const; - Square castling_rook_square(CastlingRights cr) const; - - // Checking - Bitboard checkers() const; - Bitboard blockers_for_king(Color c) const; - Bitboard check_squares(PieceType pt) const; - Bitboard pinners(Color c) const; - - // Attacks to/from a given square - Bitboard attackers_to(Square s) const; - Bitboard attackers_to(Square s, Bitboard occupied) const; - void update_slider_blockers(Color c) const; - template Bitboard attacks_by(Color c) const; - - // Properties of moves - bool legal(Move m) const; - bool pseudo_legal(const Move m) const; - bool capture(Move m) const; - bool capture_stage(Move m) const; - bool gives_check(Move m) const; - Piece moved_piece(Move m) const; - Piece captured_piece() const; - - // Doing and undoing moves - void do_move(Move m, StateInfo& newSt); - void do_move(Move m, StateInfo& newSt, bool givesCheck); - void undo_move(Move m); - void do_null_move(StateInfo& newSt); - void undo_null_move(); - - // Static Exchange Evaluation - bool see_ge(Move m, Value threshold = VALUE_ZERO) const; - bool see_ge(Move m, Bitboard& occupied, Value threshold = VALUE_ZERO) const; - - // Accessing hash keys - Key key() const; - Key key_after(Move m) const; - Key material_key() const; - - // Other properties of the position - Color side_to_move() const; - int game_ply() const; - bool is_chess960() const; - Thread* this_thread() const; - bool is_draw(int ply) const; - bool has_game_cycle(int ply) const; - bool has_repeated() const; - int rule50_count() const; - Value non_pawn_material(Color c) const; - Value non_pawn_material() const; - - // Position consistency check, for debugging - bool pos_is_ok() const; - void flip(); - - // Used by NNUE - StateInfo* state() const; - - void put_piece(Piece pc, Square s); - void remove_piece(Square s); - -private: - // Initialization helpers (used while setting up a position) - void set_castling_right(Color c, Square rfrom); - void set_state() const; - void set_check_info() const; - - // Other helpers - void move_piece(Square from, Square to); - template - void do_castling(Color us, Square from, Square& to, Square& rfrom, Square& rto); - template - Key adjust_key50(Key k) const; - - // Data members - Piece board[SQUARE_NB]; - Bitboard byTypeBB[PIECE_TYPE_NB]; - Bitboard byColorBB[COLOR_NB]; - int pieceCount[PIECE_NB]; - int castlingRightsMask[SQUARE_NB]; - Square castlingRookSquare[CASTLING_RIGHT_NB]; - Bitboard castlingPath[CASTLING_RIGHT_NB]; - Thread* thisThread; - StateInfo* st; - int gamePly; - Color sideToMove; - bool chess960; -}; - -std::ostream& operator<<(std::ostream& os, const Position& pos); - -inline Color Position::side_to_move() const { - return sideToMove; -} - -inline Piece Position::piece_on(Square s) const { - assert(is_ok(s)); - return board[s]; -} - -inline bool Position::empty(Square s) const { - return piece_on(s) == NO_PIECE; -} - -inline Piece Position::moved_piece(Move m) const { - return piece_on(from_sq(m)); -} - -inline Bitboard Position::pieces(PieceType pt) const { - return byTypeBB[pt]; -} - -template -inline Bitboard Position::pieces(PieceType pt, PieceTypes... pts) const { - return pieces(pt) | pieces(pts...); -} - -inline Bitboard Position::pieces(Color c) const { - return byColorBB[c]; -} - -template -inline Bitboard Position::pieces(Color c, PieceTypes... pts) const { - return pieces(c) & pieces(pts...); -} - -template inline int Position::count(Color c) const { - return pieceCount[make_piece(c, Pt)]; -} - -template inline int Position::count() const { - return count(WHITE) + count(BLACK); -} - -template inline Square Position::square(Color c) const { - assert(count(c) == 1); - return lsb(pieces(c, Pt)); -} - -inline Square Position::ep_square() const { - return st->epSquare; -} - -inline bool Position::can_castle(CastlingRights cr) const { - return st->castlingRights & cr; -} - -inline CastlingRights Position::castling_rights(Color c) const { - return c & CastlingRights(st->castlingRights); -} - -inline bool Position::castling_impeded(CastlingRights cr) const { - assert(cr == WHITE_OO || cr == WHITE_OOO || cr == BLACK_OO || cr == BLACK_OOO); - - return pieces() & castlingPath[cr]; -} - -inline Square Position::castling_rook_square(CastlingRights cr) const { - assert(cr == WHITE_OO || cr == WHITE_OOO || cr == BLACK_OO || cr == BLACK_OOO); - - return castlingRookSquare[cr]; -} - -inline Bitboard Position::attackers_to(Square s) const { - return attackers_to(s, pieces()); -} - -template -inline Bitboard Position::attacks_by(Color c) const { - - if constexpr (Pt == PAWN) - return c == WHITE ? pawn_attacks_bb(pieces(WHITE, PAWN)) - : pawn_attacks_bb(pieces(BLACK, PAWN)); - else - { - Bitboard threats = 0; - Bitboard attackers = pieces(c, Pt); - while (attackers) - threats |= attacks_bb(pop_lsb(attackers), pieces()); - return threats; - } -} - -inline Bitboard Position::checkers() const { - return st->checkersBB; -} - -inline Bitboard Position::blockers_for_king(Color c) const { - return st->blockersForKing[c]; -} - -inline Bitboard Position::pinners(Color c) const { - return st->pinners[c]; -} - -inline Bitboard Position::check_squares(PieceType pt) const { - return st->checkSquares[pt]; -} - -inline Key Position::key() const { - return adjust_key50(st->key); -} - -template -inline Key Position::adjust_key50(Key k) const -{ - return st->rule50 < 14 - AfterMove - ? k : k ^ make_key((st->rule50 - (14 - AfterMove)) / 8); -} - -inline Key Position::material_key() const { - return st->materialKey; -} - -inline Value Position::non_pawn_material(Color c) const { - return st->nonPawnMaterial[c]; -} - -inline Value Position::non_pawn_material() const { - return non_pawn_material(WHITE) + non_pawn_material(BLACK); -} - -inline int Position::game_ply() const { - return gamePly; -} - -inline int Position::rule50_count() const { - return st->rule50; -} - -inline bool Position::is_chess960() const { - return chess960; -} - -inline bool Position::capture(Move m) const { - assert(is_ok(m)); - return (!empty(to_sq(m)) && type_of(m) != CASTLING) - || type_of(m) == EN_PASSANT; -} - -// returns true if a move is generated from the capture stage -// having also queen promotions covered, i.e. consistency with the capture stage move generation -// is needed to avoid the generation of duplicate moves. -inline bool Position::capture_stage(Move m) const { - assert(is_ok(m)); - return capture(m) || promotion_type(m) == QUEEN; -} - -inline Piece Position::captured_piece() const { - return st->capturedPiece; -} - -inline Thread* Position::this_thread() const { - return thisThread; -} - -inline void Position::put_piece(Piece pc, Square s) { - - board[s] = pc; - byTypeBB[ALL_PIECES] |= byTypeBB[type_of(pc)] |= s; - byColorBB[color_of(pc)] |= s; - pieceCount[pc]++; - pieceCount[make_piece(color_of(pc), ALL_PIECES)]++; -} - -inline void Position::remove_piece(Square s) { - - Piece pc = board[s]; - byTypeBB[ALL_PIECES] ^= s; - byTypeBB[type_of(pc)] ^= s; - byColorBB[color_of(pc)] ^= s; - board[s] = NO_PIECE; - pieceCount[pc]--; - pieceCount[make_piece(color_of(pc), ALL_PIECES)]--; -} - -inline void Position::move_piece(Square from, Square to) { - - Piece pc = board[from]; - Bitboard fromTo = from | to; - byTypeBB[ALL_PIECES] ^= fromTo; - byTypeBB[type_of(pc)] ^= fromTo; - byColorBB[color_of(pc)] ^= fromTo; - board[from] = NO_PIECE; - board[to] = pc; -} - -inline void Position::do_move(Move m, StateInfo& newSt) { - do_move(m, newSt, gives_check(m)); -} - -inline StateInfo* Position::state() const { - - return st; -} - -} // namespace Stockfish - -#endif // #ifndef POSITION_H_INCLUDED + /// StateInfo struct stores information needed to restore a Position object to + /// its previous state when we retract a move. Whenever a move is made on the + /// board (by calling Position::do_move), a StateInfo object must be passed. + + struct StateInfo { + + // Copied when making a move + Key materialKey; + Value nonPawnMaterial[COLOR_NB]; + int castlingRights; + int rule50; + int pliesFromNull; + Square epSquare; + + // Not copied when making a move (will be recomputed anyhow) + Key key; + Bitboard checkersBB; + StateInfo* previous; + Bitboard blockersForKing[COLOR_NB]; + Bitboard pinners[COLOR_NB]; + Bitboard checkSquares[PIECE_TYPE_NB]; + Piece capturedPiece; + int repetition; + + // Used by NNUE + Eval::NNUE::Accumulator accumulator; + DirtyPiece dirtyPiece; + }; + + + /// A list to keep track of the position states along the setup moves (from the + /// start position to the position just before the search starts). Needed by + /// 'draw by repetition' detection. Use a std::deque because pointers to + /// elements are not invalidated upon list resizing. + using StateListPtr = std::unique_ptr>; + + + /// Position class stores information regarding the board representation as + /// pieces, side to move, hash keys, castling info, etc. Important methods are + /// do_move() and undo_move(), used by the search to update node info when + /// traversing the search tree. + class Thread; + + class Position { + public: + static void init(); + + Position() = default; + Position(const Position&) = delete; + Position& operator=(const Position&) = delete; + + // FEN string input/output + Position& set(const std::string& fenStr, bool isChess960, StateInfo* si, Thread* th); + Position& set(const std::string& code, Color c, StateInfo* si); + std::string fen() const; + + // Position representation + Bitboard pieces(PieceType pt = ALL_PIECES) const; + template Bitboard pieces(PieceType pt, PieceTypes... pts) const; + Bitboard pieces(Color c) const; + template Bitboard pieces(Color c, PieceTypes... pts) const; + Piece piece_on(Square s) const; + Square ep_square() const; + bool empty(Square s) const; + template int count(Color c) const; + template int count() const; + template Square square(Color c) const; + + // Castling + CastlingRights castling_rights(Color c) const; + bool can_castle(CastlingRights cr) const; + bool castling_impeded(CastlingRights cr) const; + Square castling_rook_square(CastlingRights cr) const; + + // Checking + Bitboard checkers() const; + Bitboard blockers_for_king(Color c) const; + Bitboard check_squares(PieceType pt) const; + Bitboard pinners(Color c) const; + + // Attacks to/from a given square + Bitboard attackers_to(Square s) const; + Bitboard attackers_to(Square s, Bitboard occupied) const; + void update_slider_blockers(Color c) const; + template Bitboard attacks_by(Color c) const; + + // Properties of moves + bool legal(Move m) const; + bool pseudo_legal(const Move m) const; + bool capture(Move m) const; + bool capture_stage(Move m) const; + bool gives_check(Move m) const; + Piece moved_piece(Move m) const; + Piece captured_piece() const; + + // Doing and undoing moves + void do_move(Move m, StateInfo& newSt); + void do_move(Move m, StateInfo& newSt, bool givesCheck); + void undo_move(Move m); + void do_null_move(StateInfo& newSt); + void undo_null_move(); + + // Static Exchange Evaluation + bool see_ge(Move m, Value threshold = VALUE_ZERO) const; + bool see_ge(Move m, Bitboard& occupied, Value threshold = VALUE_ZERO) const; + + // Accessing hash keys + Key key() const; + Key key_after(Move m) const; + Key material_key() const; + + // Other properties of the position + Color side_to_move() const; + int game_ply() const; + bool is_chess960() const; + Thread* this_thread() const; + bool is_draw(int ply) const; + bool has_game_cycle(int ply) const; + bool has_repeated() const; + int rule50_count() const; + Value non_pawn_material(Color c) const; + Value non_pawn_material() const; + + // Position consistency check, for debugging + bool pos_is_ok() const; + void flip(); + + // Used by NNUE + StateInfo* state() const; + + void put_piece(Piece pc, Square s); + void remove_piece(Square s); + + private: + // Initialization helpers (used while setting up a position) + void set_castling_right(Color c, Square rfrom); + void set_state() const; + void set_check_info() const; + + // Other helpers + void move_piece(Square from, Square to); + template + void do_castling(Color us, Square from, Square& to, Square& rfrom, Square& rto); + template Key adjust_key50(Key k) const; + + // Data members + Piece board[SQUARE_NB]; + Bitboard byTypeBB[PIECE_TYPE_NB]; + Bitboard byColorBB[COLOR_NB]; + int pieceCount[PIECE_NB]; + int castlingRightsMask[SQUARE_NB]; + Square castlingRookSquare[CASTLING_RIGHT_NB]; + Bitboard castlingPath[CASTLING_RIGHT_NB]; + Thread* thisThread; + StateInfo* st; + int gamePly; + Color sideToMove; + bool chess960; + }; + + std::ostream& operator<<(std::ostream& os, const Position& pos); + + inline Color Position::side_to_move() const { return sideToMove; } + + inline Piece Position::piece_on(Square s) const { + assert(is_ok(s)); + return board[s]; + } + + inline bool Position::empty(Square s) const { return piece_on(s) == NO_PIECE; } + + inline Piece Position::moved_piece(Move m) const { return piece_on(from_sq(m)); } + + inline Bitboard Position::pieces(PieceType pt) const { return byTypeBB[pt]; } + + template + inline Bitboard Position::pieces(PieceType pt, PieceTypes... pts) const { + return pieces(pt) | pieces(pts...); + } + + inline Bitboard Position::pieces(Color c) const { return byColorBB[c]; } + + template + inline Bitboard Position::pieces(Color c, PieceTypes... pts) const { + return pieces(c) & pieces(pts...); + } + + template inline int Position::count(Color c) const { + return pieceCount[make_piece(c, Pt)]; + } + + template inline int Position::count() const { + return count(WHITE) + count(BLACK); + } + + template inline Square Position::square(Color c) const { + assert(count(c) == 1); + return lsb(pieces(c, Pt)); + } + + inline Square Position::ep_square() const { return st->epSquare; } + + inline bool Position::can_castle(CastlingRights cr) const { return st->castlingRights & cr; } + + inline CastlingRights Position::castling_rights(Color c) const { + return c & CastlingRights(st->castlingRights); + } + + inline bool Position::castling_impeded(CastlingRights cr) const { + assert(cr == WHITE_OO || cr == WHITE_OOO || cr == BLACK_OO || cr == BLACK_OOO); + + return pieces() & castlingPath[cr]; + } + + inline Square Position::castling_rook_square(CastlingRights cr) const { + assert(cr == WHITE_OO || cr == WHITE_OOO || cr == BLACK_OO || cr == BLACK_OOO); + + return castlingRookSquare[cr]; + } + + inline Bitboard Position::attackers_to(Square s) const { return attackers_to(s, pieces()); } + + template inline Bitboard Position::attacks_by(Color c) const { + + if constexpr (Pt == PAWN) + return c == WHITE ? pawn_attacks_bb(pieces(WHITE, PAWN)) : + pawn_attacks_bb(pieces(BLACK, PAWN)); + else { + Bitboard threats = 0; + Bitboard attackers = pieces(c, Pt); + while (attackers) threats |= attacks_bb(pop_lsb(attackers), pieces()); + return threats; + } + } + + inline Bitboard Position::checkers() const { return st->checkersBB; } + + inline Bitboard Position::blockers_for_king(Color c) const { return st->blockersForKing[c]; } + + inline Bitboard Position::pinners(Color c) const { return st->pinners[c]; } + + inline Bitboard Position::check_squares(PieceType pt) const { return st->checkSquares[pt]; } + + inline Key Position::key() const { return adjust_key50(st->key); } + + template inline Key Position::adjust_key50(Key k) const { + return st->rule50 < 14 - AfterMove ? k : k ^ make_key((st->rule50 - (14 - AfterMove)) / 8); + } + + inline Key Position::material_key() const { return st->materialKey; } + + inline Value Position::non_pawn_material(Color c) const { return st->nonPawnMaterial[c]; } + + inline Value Position::non_pawn_material() const { + return non_pawn_material(WHITE) + non_pawn_material(BLACK); + } + + inline int Position::game_ply() const { return gamePly; } + + inline int Position::rule50_count() const { return st->rule50; } + + inline bool Position::is_chess960() const { return chess960; } + + inline bool Position::capture(Move m) const { + assert(is_ok(m)); + return (!empty(to_sq(m)) && type_of(m) != CASTLING) || type_of(m) == EN_PASSANT; + } + + // returns true if a move is generated from the capture stage + // having also queen promotions covered, i.e. consistency with the capture stage move generation + // is needed to avoid the generation of duplicate moves. + inline bool Position::capture_stage(Move m) const { + assert(is_ok(m)); + return capture(m) || promotion_type(m) == QUEEN; + } + + inline Piece Position::captured_piece() const { return st->capturedPiece; } + + inline Thread* Position::this_thread() const { return thisThread; } + + inline void Position::put_piece(Piece pc, Square s) { + + board[s] = pc; + byTypeBB[ALL_PIECES] |= byTypeBB[type_of(pc)] |= s; + byColorBB[color_of(pc)] |= s; + pieceCount[pc]++; + pieceCount[make_piece(color_of(pc), ALL_PIECES)]++; + } + + inline void Position::remove_piece(Square s) { + + Piece pc = board[s]; + byTypeBB[ALL_PIECES] ^= s; + byTypeBB[type_of(pc)] ^= s; + byColorBB[color_of(pc)] ^= s; + board[s] = NO_PIECE; + pieceCount[pc]--; + pieceCount[make_piece(color_of(pc), ALL_PIECES)]--; + } + + inline void Position::move_piece(Square from, Square to) { + + Piece pc = board[from]; + Bitboard fromTo = from | to; + byTypeBB[ALL_PIECES] ^= fromTo; + byTypeBB[type_of(pc)] ^= fromTo; + byColorBB[color_of(pc)] ^= fromTo; + board[from] = NO_PIECE; + board[to] = pc; + } + + inline void Position::do_move(Move m, StateInfo& newSt) { do_move(m, newSt, gives_check(m)); } + + inline StateInfo* Position::state() const { return st; } + +} // namespace Stockfish + +#endif // #ifndef POSITION_H_INCLUDED diff --git a/src/search.cpp b/src/search.cpp index 4b403c49d8e..95ae93feb0b 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -47,1961 +47,1764 @@ namespace Stockfish { -namespace Search { - - LimitsType Limits; -} - -namespace Tablebases { - - int Cardinality; - bool RootInTB; - bool UseRule50; - Depth ProbeDepth; -} - -namespace TB = Tablebases; - -using std::string; -using Eval::evaluate; -using namespace Search; - -namespace { - - // Different node types, used as a template parameter - enum NodeType { NonPV, PV, Root }; - - // Futility margin - Value futility_margin(Depth d, bool noTtCutNode, bool improving) { - return Value((140 - 40 * noTtCutNode) * (d - improving)); - } - - // Reductions lookup table initialized at startup - int Reductions[MAX_MOVES]; // [depth or moveNumber] - - Depth reduction(bool i, Depth d, int mn, Value delta, Value rootDelta) { - int reductionScale = Reductions[d] * Reductions[mn]; - return (reductionScale + 1372 - int(delta) * 1073 / int(rootDelta)) / 1024 - + (!i && reductionScale > 936); - } - - constexpr int futility_move_count(bool improving, Depth depth) { - return improving ? (3 + depth * depth) - : (3 + depth * depth) / 2; - } - - // History and stats update bonus, based on depth - int stat_bonus(Depth d) { - return std::min(336 * d - 547, 1561); - } - - // Add a small random component to draw evaluations to avoid 3-fold blindness - Value value_draw(const Thread* thisThread) { - return VALUE_DRAW - 1 + Value(thisThread->nodes & 0x2); - } - - // Skill structure is used to implement strength limit. If we have an uci_elo then - // we convert it to a suitable fractional skill level using anchoring to CCRL Elo - // (goldfish 1.13 = 2000) and a fit through Ordo derived Elo for a match (TC 60+0.6) - // results spanning a wide range of k values. - struct Skill { - Skill(int skill_level, int uci_elo) { - if (uci_elo) - { - double e = double(uci_elo - 1320) / (3190 - 1320); - level = std::clamp((((37.2473 * e - 40.8525) * e + 22.2943) * e - 0.311438), 0.0, 19.0); - } - else - level = double(skill_level); + namespace Search { + + LimitsType Limits; } - bool enabled() const { return level < 20.0; } - bool time_to_pick(Depth depth) const { return depth == 1 + int(level); } - Move pick_best(size_t multiPV); - - double level; - Move best = MOVE_NONE; - }; - - template - Value search(Position& pos, Stack* ss, Value alpha, Value beta, Depth depth, bool cutNode); - - template - Value qsearch(Position& pos, Stack* ss, Value alpha, Value beta, Depth depth = 0); - - Value value_to_tt(Value v, int ply); - Value value_from_tt(Value v, int ply, int r50c); - void update_pv(Move* pv, Move move, const Move* childPv); - void update_continuation_histories(Stack* ss, Piece pc, Square to, int bonus); - void update_quiet_stats(const Position& pos, Stack* ss, Move move, int bonus); - void update_all_stats(const Position& pos, Stack* ss, Move bestMove, Value bestValue, Value beta, Square prevSq, - Move* quietsSearched, int quietCount, Move* capturesSearched, int captureCount, Depth depth); - - // perft() is our utility to verify move generation. All the leaf nodes up - // to the given depth are generated and counted, and the sum is returned. - template - uint64_t perft(Position& pos, Depth depth) { - - StateInfo st; - ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize); - - uint64_t cnt, nodes = 0; - const bool leaf = (depth == 2); - - for (const auto& m : MoveList(pos)) - { - if (Root && depth <= 1) - cnt = 1, nodes++; - else - { - pos.do_move(m, st); - cnt = leaf ? MoveList(pos).size() : perft(pos, depth - 1); - nodes += cnt; - pos.undo_move(m); - } - if (Root) - sync_cout << UCI::move(m, pos.is_chess960()) << ": " << cnt << sync_endl; + + namespace Tablebases { + + int Cardinality; + bool RootInTB; + bool UseRule50; + Depth ProbeDepth; } - return nodes; - } -} // namespace + namespace TB = Tablebases; + + using std::string; + using Eval::evaluate; + using namespace Search; + + namespace { + + // Different node types, used as a template parameter + enum NodeType { + NonPV, + PV, + Root + }; + + // Futility margin + Value futility_margin(Depth d, bool noTtCutNode, bool improving) { + return Value((140 - 40 * noTtCutNode) * (d - improving)); + } + + // Reductions lookup table initialized at startup + int Reductions[MAX_MOVES]; // [depth or moveNumber] + + Depth reduction(bool i, Depth d, int mn, Value delta, Value rootDelta) { + int reductionScale = Reductions[d] * Reductions[mn]; + return (reductionScale + 1372 - int(delta) * 1073 / int(rootDelta)) / 1024 + + (!i && reductionScale > 936); + } + + constexpr int futility_move_count(bool improving, Depth depth) { + return improving ? (3 + depth * depth) : (3 + depth * depth) / 2; + } + + // History and stats update bonus, based on depth + int stat_bonus(Depth d) { return std::min(336 * d - 547, 1561); } + // Add a small random component to draw evaluations to avoid 3-fold blindness + Value value_draw(const Thread* thisThread) { + return VALUE_DRAW - 1 + Value(thisThread->nodes & 0x2); + } -/// Search::init() is called at startup to initialize various lookup tables + // Skill structure is used to implement strength limit. If we have an uci_elo then + // we convert it to a suitable fractional skill level using anchoring to CCRL Elo + // (goldfish 1.13 = 2000) and a fit through Ordo derived Elo for a match (TC 60+0.6) + // results spanning a wide range of k values. + struct Skill { + Skill(int skill_level, int uci_elo) { + if (uci_elo) { + double e = double(uci_elo - 1320) / (3190 - 1320); + level = std::clamp((((37.2473 * e - 40.8525) * e + 22.2943) * e - 0.311438), + 0.0, 19.0); + } else + level = double(skill_level); + } + bool enabled() const { return level < 20.0; } + bool time_to_pick(Depth depth) const { return depth == 1 + int(level); } + Move pick_best(size_t multiPV); + + double level; + Move best = MOVE_NONE; + }; + + template + Value search(Position& pos, Stack* ss, Value alpha, Value beta, Depth depth, bool cutNode); + + template + Value qsearch(Position& pos, Stack* ss, Value alpha, Value beta, Depth depth = 0); + + Value value_to_tt(Value v, int ply); + Value value_from_tt(Value v, int ply, int r50c); + void update_pv(Move* pv, Move move, const Move* childPv); + void update_continuation_histories(Stack* ss, Piece pc, Square to, int bonus); + void update_quiet_stats(const Position& pos, Stack* ss, Move move, int bonus); + void update_all_stats(const Position& pos, Stack* ss, Move bestMove, Value bestValue, + Value beta, Square prevSq, Move* quietsSearched, int quietCount, + Move* capturesSearched, int captureCount, Depth depth); + + // perft() is our utility to verify move generation. All the leaf nodes up + // to the given depth are generated and counted, and the sum is returned. + template uint64_t perft(Position& pos, Depth depth) { + + StateInfo st; + ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize); + + uint64_t cnt, nodes = 0; + const bool leaf = (depth == 2); + + for (const auto& m : MoveList(pos)) { + if (Root && depth <= 1) + cnt = 1, nodes++; + else { + pos.do_move(m, st); + cnt = leaf ? MoveList(pos).size() : perft(pos, depth - 1); + nodes += cnt; + pos.undo_move(m); + } + if (Root) sync_cout << UCI::move(m, pos.is_chess960()) << ": " << cnt << sync_endl; + } + return nodes; + } -void Search::init() { + } // namespace - for (int i = 1; i < MAX_MOVES; ++i) - Reductions[i] = int((20.57 + std::log(Threads.size()) / 2) * std::log(i)); -} + /// Search::init() is called at startup to initialize various lookup tables -/// Search::clear() resets search state to its initial value + void Search::init() { -void Search::clear() { + for (int i = 1; i < MAX_MOVES; ++i) + Reductions[i] = int((20.57 + std::log(Threads.size()) / 2) * std::log(i)); + } - Threads.main()->wait_for_search_finished(); - Time.availableNodes = 0; - TT.clear(); - Threads.clear(); - Tablebases::init(Options["SyzygyPath"]); // Free mapped files -} + /// Search::clear() resets search state to its initial value + void Search::clear() { -/// MainThread::search() is started when the program receives the UCI 'go' -/// command. It searches from the root position and outputs the "bestmove". + Threads.main()->wait_for_search_finished(); -void MainThread::search() { + Time.availableNodes = 0; + TT.clear(); + Threads.clear(); + Tablebases::init(Options["SyzygyPath"]); // Free mapped files + } - if (Limits.perft) - { - nodes = perft(rootPos, Limits.perft); - sync_cout << "\nNodes searched: " << nodes << "\n" << sync_endl; - return; - } - Color us = rootPos.side_to_move(); - Time.init(Limits, us, rootPos.game_ply()); - TT.new_search(); + /// MainThread::search() is started when the program receives the UCI 'go' + /// command. It searches from the root position and outputs the "bestmove". - Eval::NNUE::verify(); + void MainThread::search() { - if (rootMoves.empty()) - { - rootMoves.emplace_back(MOVE_NONE); - sync_cout << "info depth 0 score " - << UCI::value(rootPos.checkers() ? -VALUE_MATE : VALUE_DRAW) - << sync_endl; - } - else - { - Threads.start_searching(); // start non-main threads - Thread::search(); // main thread start searching - } + if (Limits.perft) { + nodes = perft(rootPos, Limits.perft); + sync_cout << "\nNodes searched: " << nodes << "\n" << sync_endl; + return; + } + + Color us = rootPos.side_to_move(); + Time.init(Limits, us, rootPos.game_ply()); + TT.new_search(); - // When we reach the maximum depth, we can arrive here without a raise of - // Threads.stop. However, if we are pondering or in an infinite search, - // the UCI protocol states that we shouldn't print the best move before the - // GUI sends a "stop" or "ponderhit" command. We therefore simply wait here - // until the GUI sends one of those commands. + Eval::NNUE::verify(); - while (!Threads.stop && (ponder || Limits.infinite)) - {} // Busy wait for a stop or a ponder reset + if (rootMoves.empty()) { + rootMoves.emplace_back(MOVE_NONE); + sync_cout << "info depth 0 score " + << UCI::value(rootPos.checkers() ? -VALUE_MATE : VALUE_DRAW) << sync_endl; + } else { + Threads.start_searching(); // start non-main threads + Thread::search(); // main thread start searching + } - // Stop the threads if not already stopped (also raise the stop if - // "ponderhit" just reset Threads.ponder). - Threads.stop = true; + // When we reach the maximum depth, we can arrive here without a raise of + // Threads.stop. However, if we are pondering or in an infinite search, + // the UCI protocol states that we shouldn't print the best move before the + // GUI sends a "stop" or "ponderhit" command. We therefore simply wait here + // until the GUI sends one of those commands. - // Wait until all threads have finished - Threads.wait_for_search_finished(); + while (!Threads.stop && (ponder || Limits.infinite)) { + } // Busy wait for a stop or a ponder reset - // When playing in 'nodes as time' mode, subtract the searched nodes from - // the available ones before exiting. - if (Limits.npmsec) - Time.availableNodes += Limits.inc[us] - Threads.nodes_searched(); + // Stop the threads if not already stopped (also raise the stop if + // "ponderhit" just reset Threads.ponder). + Threads.stop = true; - Thread* bestThread = this; - Skill skill = Skill(Options["Skill Level"], Options["UCI_LimitStrength"] ? int(Options["UCI_Elo"]) : 0); + // Wait until all threads have finished + Threads.wait_for_search_finished(); - if ( int(Options["MultiPV"]) == 1 - && !Limits.depth - && !skill.enabled() - && rootMoves[0].pv[0] != MOVE_NONE) - bestThread = Threads.get_best_thread(); + // When playing in 'nodes as time' mode, subtract the searched nodes from + // the available ones before exiting. + if (Limits.npmsec) Time.availableNodes += Limits.inc[us] - Threads.nodes_searched(); - bestPreviousScore = bestThread->rootMoves[0].score; - bestPreviousAverageScore = bestThread->rootMoves[0].averageScore; + Thread* bestThread = this; + Skill skill = + Skill(Options["Skill Level"], Options["UCI_LimitStrength"] ? int(Options["UCI_Elo"]) : 0); - // Send again PV info if we have a new best thread - if (bestThread != this) - sync_cout << UCI::pv(bestThread->rootPos, bestThread->completedDepth) << sync_endl; + if (int(Options["MultiPV"]) == 1 && !Limits.depth && !skill.enabled() && + rootMoves[0].pv[0] != MOVE_NONE) + bestThread = Threads.get_best_thread(); - sync_cout << "bestmove " << UCI::move(bestThread->rootMoves[0].pv[0], rootPos.is_chess960()); + bestPreviousScore = bestThread->rootMoves[0].score; + bestPreviousAverageScore = bestThread->rootMoves[0].averageScore; - if (bestThread->rootMoves[0].pv.size() > 1 || bestThread->rootMoves[0].extract_ponder_from_tt(rootPos)) - std::cout << " ponder " << UCI::move(bestThread->rootMoves[0].pv[1], rootPos.is_chess960()); + // Send again PV info if we have a new best thread + if (bestThread != this) + sync_cout << UCI::pv(bestThread->rootPos, bestThread->completedDepth) << sync_endl; - std::cout << sync_endl; -} - - -/// Thread::search() is the main iterative deepening loop. It calls search() -/// repeatedly with increasing depth until the allocated thinking time has been -/// consumed, the user stops the search, or the maximum search depth is reached. - -void Thread::search() { - - // To allow access to (ss-7) up to (ss+2), the stack must be oversized. - // The former is needed to allow update_continuation_histories(ss-1, ...), - // which accesses its argument at ss-6, also near the root. - // The latter is needed for statScore and killer initialization. - Stack stack[MAX_PLY+10], *ss = stack+7; - Move pv[MAX_PLY+1]; - Value alpha, beta, delta; - Move lastBestMove = MOVE_NONE; - Depth lastBestMoveDepth = 0; - MainThread* mainThread = (this == Threads.main() ? Threads.main() : nullptr); - double timeReduction = 1, totBestMoveChanges = 0; - Color us = rootPos.side_to_move(); - int iterIdx = 0; - - std::memset(ss-7, 0, 10 * sizeof(Stack)); - for (int i = 7; i > 0; --i) - { - (ss-i)->continuationHistory = &this->continuationHistory[0][0][NO_PIECE][0]; // Use as a sentinel - (ss-i)->staticEval = VALUE_NONE; - } - - for (int i = 0; i <= MAX_PLY + 2; ++i) - (ss+i)->ply = i; - - ss->pv = pv; - - bestValue = -VALUE_INFINITE; - - if (mainThread) - { - if (mainThread->bestPreviousScore == VALUE_INFINITE) - for (int i = 0; i < 4; ++i) - mainThread->iterValue[i] = VALUE_ZERO; - else - for (int i = 0; i < 4; ++i) - mainThread->iterValue[i] = mainThread->bestPreviousScore; - } - - size_t multiPV = size_t(Options["MultiPV"]); - Skill skill(Options["Skill Level"], Options["UCI_LimitStrength"] ? int(Options["UCI_Elo"]) : 0); - - // When playing with strength handicap enable MultiPV search that we will - // use behind-the-scenes to retrieve a set of possible moves. - if (skill.enabled()) - multiPV = std::max(multiPV, (size_t)4); - - multiPV = std::min(multiPV, rootMoves.size()); - - int searchAgainCounter = 0; - - // Iterative deepening loop until requested to stop or the target depth is reached - while ( ++rootDepth < MAX_PLY - && !Threads.stop - && !(Limits.depth && mainThread && rootDepth > Limits.depth)) - { - // Age out PV variability metric - if (mainThread) - totBestMoveChanges /= 2; - - // Save the last iteration's scores before the first PV line is searched and - // all the move scores except the (new) PV are set to -VALUE_INFINITE. - for (RootMove& rm : rootMoves) - rm.previousScore = rm.score; - - size_t pvFirst = 0; - pvLast = 0; - - if (!Threads.increaseDepth) - searchAgainCounter++; - - // MultiPV loop. We perform a full root search for each PV line - for (pvIdx = 0; pvIdx < multiPV && !Threads.stop; ++pvIdx) - { - if (pvIdx == pvLast) - { - pvFirst = pvLast; - for (pvLast++; pvLast < rootMoves.size(); pvLast++) - if (rootMoves[pvLast].tbRank != rootMoves[pvFirst].tbRank) - break; - } - - // Reset UCI info selDepth for each depth and each PV line - selDepth = 0; - - // Reset aspiration window starting size - Value prev = rootMoves[pvIdx].averageScore; - delta = Value(10) + int(prev) * prev / 15799; - alpha = std::max(prev - delta,-VALUE_INFINITE); - beta = std::min(prev + delta, VALUE_INFINITE); - - // Adjust optimism based on root move's previousScore - int opt = 109 * prev / (std::abs(prev) + 141); - optimism[ us] = Value(opt); - optimism[~us] = -optimism[us]; - - // Start with a small aspiration window and, in the case of a fail - // high/low, re-search with a bigger window until we don't fail - // high/low anymore. - int failedHighCnt = 0; - while (true) - { - // Adjust the effective depth searched, but ensure at least one effective increment for every - // four searchAgain steps (see issue #2717). - Depth adjustedDepth = std::max(1, rootDepth - failedHighCnt - 3 * (searchAgainCounter + 1) / 4); - bestValue = Stockfish::search(rootPos, ss, alpha, beta, adjustedDepth, false); - - // Bring the best move to the front. It is critical that sorting - // is done with a stable algorithm because all the values but the - // first and eventually the new best one is set to -VALUE_INFINITE - // and we want to keep the same order for all the moves except the - // new PV that goes to the front. Note that in the case of MultiPV - // search the already searched PV lines are preserved. - std::stable_sort(rootMoves.begin() + pvIdx, rootMoves.begin() + pvLast); - - // If search has been stopped, we break immediately. Sorting is - // safe because RootMoves is still valid, although it refers to - // the previous iteration. - if (Threads.stop) - break; - - // When failing high/low give some update (without cluttering - // the UI) before a re-search. - if ( mainThread - && multiPV == 1 - && (bestValue <= alpha || bestValue >= beta) - && Time.elapsed() > 3000) - sync_cout << UCI::pv(rootPos, rootDepth) << sync_endl; - - // In case of failing low/high increase aspiration window and - // re-search, otherwise exit the loop. - if (bestValue <= alpha) - { - beta = (alpha + beta) / 2; - alpha = std::max(bestValue - delta, -VALUE_INFINITE); - - failedHighCnt = 0; - if (mainThread) - mainThread->stopOnPonderhit = false; - } - else if (bestValue >= beta) - { - beta = std::min(bestValue + delta, VALUE_INFINITE); - ++failedHighCnt; - } - else - break; - - delta += delta / 3; - - assert(alpha >= -VALUE_INFINITE && beta <= VALUE_INFINITE); - } - - // Sort the PV lines searched so far and update the GUI - std::stable_sort(rootMoves.begin() + pvFirst, rootMoves.begin() + pvIdx + 1); - - if ( mainThread - && (Threads.stop || pvIdx + 1 == multiPV || Time.elapsed() > 3000)) - sync_cout << UCI::pv(rootPos, rootDepth) << sync_endl; - } - - if (!Threads.stop) - completedDepth = rootDepth; - - if (rootMoves[0].pv[0] != lastBestMove) - { - lastBestMove = rootMoves[0].pv[0]; - lastBestMoveDepth = rootDepth; - } - - // Have we found a "mate in x"? - if ( Limits.mate - && bestValue >= VALUE_MATE_IN_MAX_PLY - && VALUE_MATE - bestValue <= 2 * Limits.mate) - Threads.stop = true; - - if (!mainThread) - continue; - - // If the skill level is enabled and time is up, pick a sub-optimal best move - if (skill.enabled() && skill.time_to_pick(rootDepth)) - skill.pick_best(multiPV); - - // Use part of the gained time from a previous stable move for the current move - for (Thread* th : Threads) - { - totBestMoveChanges += th->bestMoveChanges; - th->bestMoveChanges = 0; - } - - // Do we have time for the next iteration? Can we stop searching now? - if ( Limits.use_time_management() - && !Threads.stop - && !mainThread->stopOnPonderhit) - { - double fallingEval = (69 + 13 * (mainThread->bestPreviousAverageScore - bestValue) - + 6 * (mainThread->iterValue[iterIdx] - bestValue)) / 619.6; - fallingEval = std::clamp(fallingEval, 0.5, 1.5); - - // If the bestMove is stable over several iterations, reduce time accordingly - timeReduction = lastBestMoveDepth + 8 < completedDepth ? 1.57 : 0.65; - double reduction = (1.4 + mainThread->previousTimeReduction) / (2.08 * timeReduction); - double bestMoveInstability = 1 + 1.8 * totBestMoveChanges / Threads.size(); - - double totalTime = Time.optimum() * fallingEval * reduction * bestMoveInstability; - - // Cap used time in case of a single legal move for a better viewer experience in tournaments - // yielding correct scores and sufficiently fast moves. - if (rootMoves.size() == 1) - totalTime = std::min(500.0, totalTime); - - // Stop the search if we have exceeded the totalTime - if (Time.elapsed() > totalTime) - { - // If we are allowed to ponder do not stop the search now but - // keep pondering until the GUI sends "ponderhit" or "stop". - if (mainThread->ponder) - mainThread->stopOnPonderhit = true; - else - Threads.stop = true; - } - else if ( !mainThread->ponder - && Time.elapsed() > totalTime * 0.50) - Threads.increaseDepth = false; - else - Threads.increaseDepth = true; - } - - mainThread->iterValue[iterIdx] = bestValue; - iterIdx = (iterIdx + 1) & 3; - } - - if (!mainThread) - return; - - mainThread->previousTimeReduction = timeReduction; - - // If the skill level is enabled, swap the best PV line with the sub-optimal one - if (skill.enabled()) - std::swap(rootMoves[0], *std::find(rootMoves.begin(), rootMoves.end(), - skill.best ? skill.best : skill.pick_best(multiPV))); -} - - -namespace { - - // search<>() is the main search function for both PV and non-PV nodes - - template - Value search(Position& pos, Stack* ss, Value alpha, Value beta, Depth depth, bool cutNode) { - - constexpr bool PvNode = nodeType != NonPV; - constexpr bool rootNode = nodeType == Root; - - // Check if we have an upcoming move that draws by repetition, or - // if the opponent had an alternative move earlier to this position. - if ( !rootNode - && alpha < VALUE_DRAW - && pos.has_game_cycle(ss->ply)) - { - alpha = value_draw(pos.this_thread()); - if (alpha >= beta) - return alpha; - } + sync_cout << "bestmove " + << UCI::move(bestThread->rootMoves[0].pv[0], rootPos.is_chess960()); - // Dive into quiescence search when the depth reaches zero - if (depth <= 0) - return qsearch(pos, ss, alpha, beta); - - assert(-VALUE_INFINITE <= alpha && alpha < beta && beta <= VALUE_INFINITE); - assert(PvNode || (alpha == beta - 1)); - assert(0 < depth && depth < MAX_PLY); - assert(!(PvNode && cutNode)); - - Move pv[MAX_PLY+1], capturesSearched[32], quietsSearched[64]; - StateInfo st; - ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize); - - TTEntry* tte; - Key posKey; - Move ttMove, move, excludedMove, bestMove; - Depth extension, newDepth; - Value bestValue, value, ttValue, eval, maxValue, probCutBeta; - bool givesCheck, improving, priorCapture, singularQuietLMR; - bool capture, moveCountPruning, ttCapture; - Piece movedPiece; - int moveCount, captureCount, quietCount; - - // Step 1. Initialize node - Thread* thisThread = pos.this_thread(); - ss->inCheck = pos.checkers(); - priorCapture = pos.captured_piece(); - Color us = pos.side_to_move(); - moveCount = captureCount = quietCount = ss->moveCount = 0; - bestValue = -VALUE_INFINITE; - maxValue = VALUE_INFINITE; - - // Check for the available remaining time - if (thisThread == Threads.main()) - static_cast(thisThread)->check_time(); - - // Used to send selDepth info to GUI (selDepth counts from 1, ply from 0) - if (PvNode && thisThread->selDepth < ss->ply + 1) - thisThread->selDepth = ss->ply + 1; - - if (!rootNode) - { - // Step 2. Check for aborted search and immediate draw - if ( Threads.stop.load(std::memory_order_relaxed) - || pos.is_draw(ss->ply) - || ss->ply >= MAX_PLY) - return (ss->ply >= MAX_PLY && !ss->inCheck) ? evaluate(pos) - : value_draw(pos.this_thread()); - - // Step 3. Mate distance pruning. Even if we mate at the next move our score - // would be at best mate_in(ss->ply+1), but if alpha is already bigger because - // a shorter mate was found upward in the tree then there is no need to search - // because we will never beat the current alpha. Same logic but with reversed - // signs apply also in the opposite condition of being mated instead of giving - // mate. In this case, return a fail-high score. - alpha = std::max(mated_in(ss->ply), alpha); - beta = std::min(mate_in(ss->ply+1), beta); - if (alpha >= beta) - return alpha; + if (bestThread->rootMoves[0].pv.size() > 1 || + bestThread->rootMoves[0].extract_ponder_from_tt(rootPos)) + std::cout << " ponder " + << UCI::move(bestThread->rootMoves[0].pv[1], rootPos.is_chess960()); + + std::cout << sync_endl; } - else - thisThread->rootDelta = beta - alpha; - - assert(0 <= ss->ply && ss->ply < MAX_PLY); - - (ss+1)->excludedMove = bestMove = MOVE_NONE; - (ss+2)->killers[0] = (ss+2)->killers[1] = MOVE_NONE; - (ss+2)->cutoffCnt = 0; - ss->doubleExtensions = (ss-1)->doubleExtensions; - Square prevSq = is_ok((ss-1)->currentMove) ? to_sq((ss-1)->currentMove) : SQ_NONE; - ss->statScore = 0; - - // Step 4. Transposition table lookup. - excludedMove = ss->excludedMove; - posKey = pos.key(); - tte = TT.probe(posKey, ss->ttHit); - ttValue = ss->ttHit ? value_from_tt(tte->value(), ss->ply, pos.rule50_count()) : VALUE_NONE; - ttMove = rootNode ? thisThread->rootMoves[thisThread->pvIdx].pv[0] - : ss->ttHit ? tte->move() : MOVE_NONE; - ttCapture = ttMove && pos.capture_stage(ttMove); - - // At this point, if excluded, skip straight to step 6, static eval. However, - // to save indentation, we list the condition in all code between here and there. - if (!excludedMove) - ss->ttPv = PvNode || (ss->ttHit && tte->is_pv()); - - // At non-PV nodes we check for an early TT cutoff - if ( !PvNode - && !excludedMove - && tte->depth() > depth - && ttValue != VALUE_NONE // Possible in case of TT access race or if !ttHit - && (tte->bound() & (ttValue >= beta ? BOUND_LOWER : BOUND_UPPER))) - { - // If ttMove is quiet, update move sorting heuristics on TT hit (~2 Elo) - if (ttMove) - { - if (ttValue >= beta) - { - // Bonus for a quiet ttMove that fails high (~2 Elo) - if (!ttCapture) - update_quiet_stats(pos, ss, ttMove, stat_bonus(depth)); - // Extra penalty for early quiet moves of the previous ply (~0 Elo on STC, ~2 Elo on LTC) - if (prevSq != SQ_NONE && (ss-1)->moveCount <= 2 && !priorCapture) - update_continuation_histories(ss-1, pos.piece_on(prevSq), prevSq, -stat_bonus(depth + 1)); + + /// Thread::search() is the main iterative deepening loop. It calls search() + /// repeatedly with increasing depth until the allocated thinking time has been + /// consumed, the user stops the search, or the maximum search depth is reached. + + void Thread::search() { + + // To allow access to (ss-7) up to (ss+2), the stack must be oversized. + // The former is needed to allow update_continuation_histories(ss-1, ...), + // which accesses its argument at ss-6, also near the root. + // The latter is needed for statScore and killer initialization. + Stack stack[MAX_PLY + 10], *ss = stack + 7; + Move pv[MAX_PLY + 1]; + Value alpha, beta, delta; + Move lastBestMove = MOVE_NONE; + Depth lastBestMoveDepth = 0; + MainThread* mainThread = (this == Threads.main() ? Threads.main() : nullptr); + double timeReduction = 1, totBestMoveChanges = 0; + Color us = rootPos.side_to_move(); + int iterIdx = 0; + + std::memset(ss - 7, 0, 10 * sizeof(Stack)); + for (int i = 7; i > 0; --i) { + (ss - i)->continuationHistory = + &this->continuationHistory[0][0][NO_PIECE][0]; // Use as a sentinel + (ss - i)->staticEval = VALUE_NONE; + } + + for (int i = 0; i <= MAX_PLY + 2; ++i) (ss + i)->ply = i; + + ss->pv = pv; + + bestValue = -VALUE_INFINITE; + + if (mainThread) { + if (mainThread->bestPreviousScore == VALUE_INFINITE) + for (int i = 0; i < 4; ++i) mainThread->iterValue[i] = VALUE_ZERO; + else + for (int i = 0; i < 4; ++i) + mainThread->iterValue[i] = mainThread->bestPreviousScore; + } + + size_t multiPV = size_t(Options["MultiPV"]); + Skill skill(Options["Skill Level"], + Options["UCI_LimitStrength"] ? int(Options["UCI_Elo"]) : 0); + + // When playing with strength handicap enable MultiPV search that we will + // use behind-the-scenes to retrieve a set of possible moves. + if (skill.enabled()) multiPV = std::max(multiPV, (size_t) 4); + + multiPV = std::min(multiPV, rootMoves.size()); + + int searchAgainCounter = 0; + + // Iterative deepening loop until requested to stop or the target depth is reached + while (++rootDepth < MAX_PLY && !Threads.stop && + !(Limits.depth && mainThread && rootDepth > Limits.depth)) { + // Age out PV variability metric + if (mainThread) totBestMoveChanges /= 2; + + // Save the last iteration's scores before the first PV line is searched and + // all the move scores except the (new) PV are set to -VALUE_INFINITE. + for (RootMove& rm : rootMoves) rm.previousScore = rm.score; + + size_t pvFirst = 0; + pvLast = 0; + + if (!Threads.increaseDepth) searchAgainCounter++; + + // MultiPV loop. We perform a full root search for each PV line + for (pvIdx = 0; pvIdx < multiPV && !Threads.stop; ++pvIdx) { + if (pvIdx == pvLast) { + pvFirst = pvLast; + for (pvLast++; pvLast < rootMoves.size(); pvLast++) + if (rootMoves[pvLast].tbRank != rootMoves[pvFirst].tbRank) break; + } + + // Reset UCI info selDepth for each depth and each PV line + selDepth = 0; + + // Reset aspiration window starting size + Value prev = rootMoves[pvIdx].averageScore; + delta = Value(10) + int(prev) * prev / 15799; + alpha = std::max(prev - delta, -VALUE_INFINITE); + beta = std::min(prev + delta, VALUE_INFINITE); + + // Adjust optimism based on root move's previousScore + int opt = 109 * prev / (std::abs(prev) + 141); + optimism[us] = Value(opt); + optimism[~us] = -optimism[us]; + + // Start with a small aspiration window and, in the case of a fail + // high/low, re-search with a bigger window until we don't fail + // high/low anymore. + int failedHighCnt = 0; + while (true) { + // Adjust the effective depth searched, but ensure at least one effective increment for every + // four searchAgain steps (see issue #2717). + Depth adjustedDepth = + std::max(1, rootDepth - failedHighCnt - 3 * (searchAgainCounter + 1) / 4); + bestValue = + Stockfish::search(rootPos, ss, alpha, beta, adjustedDepth, false); + + // Bring the best move to the front. It is critical that sorting + // is done with a stable algorithm because all the values but the + // first and eventually the new best one is set to -VALUE_INFINITE + // and we want to keep the same order for all the moves except the + // new PV that goes to the front. Note that in the case of MultiPV + // search the already searched PV lines are preserved. + std::stable_sort(rootMoves.begin() + pvIdx, rootMoves.begin() + pvLast); + + // If search has been stopped, we break immediately. Sorting is + // safe because RootMoves is still valid, although it refers to + // the previous iteration. + if (Threads.stop) break; + + // When failing high/low give some update (without cluttering + // the UI) before a re-search. + if (mainThread && multiPV == 1 && (bestValue <= alpha || bestValue >= beta) && + Time.elapsed() > 3000) + sync_cout << UCI::pv(rootPos, rootDepth) << sync_endl; + + // In case of failing low/high increase aspiration window and + // re-search, otherwise exit the loop. + if (bestValue <= alpha) { + beta = (alpha + beta) / 2; + alpha = std::max(bestValue - delta, -VALUE_INFINITE); + + failedHighCnt = 0; + if (mainThread) mainThread->stopOnPonderhit = false; + } else if (bestValue >= beta) { + beta = std::min(bestValue + delta, VALUE_INFINITE); + ++failedHighCnt; + } else + break; + + delta += delta / 3; + + assert(alpha >= -VALUE_INFINITE && beta <= VALUE_INFINITE); + } + + // Sort the PV lines searched so far and update the GUI + std::stable_sort(rootMoves.begin() + pvFirst, rootMoves.begin() + pvIdx + 1); + + if (mainThread && (Threads.stop || pvIdx + 1 == multiPV || Time.elapsed() > 3000)) + sync_cout << UCI::pv(rootPos, rootDepth) << sync_endl; } - // Penalty for a quiet ttMove that fails low (~1 Elo) - else if (!ttCapture) - { - int penalty = -stat_bonus(depth); - thisThread->mainHistory[us][from_to(ttMove)] << penalty; - update_continuation_histories(ss, pos.moved_piece(ttMove), to_sq(ttMove), penalty); + + if (!Threads.stop) completedDepth = rootDepth; + + if (rootMoves[0].pv[0] != lastBestMove) { + lastBestMove = rootMoves[0].pv[0]; + lastBestMoveDepth = rootDepth; } + + // Have we found a "mate in x"? + if (Limits.mate && bestValue >= VALUE_MATE_IN_MAX_PLY && + VALUE_MATE - bestValue <= 2 * Limits.mate) + Threads.stop = true; + + if (!mainThread) continue; + + // If the skill level is enabled and time is up, pick a sub-optimal best move + if (skill.enabled() && skill.time_to_pick(rootDepth)) skill.pick_best(multiPV); + + // Use part of the gained time from a previous stable move for the current move + for (Thread* th : Threads) { + totBestMoveChanges += th->bestMoveChanges; + th->bestMoveChanges = 0; + } + + // Do we have time for the next iteration? Can we stop searching now? + if (Limits.use_time_management() && !Threads.stop && !mainThread->stopOnPonderhit) { + double fallingEval = (69 + 13 * (mainThread->bestPreviousAverageScore - bestValue) + + 6 * (mainThread->iterValue[iterIdx] - bestValue)) / + 619.6; + fallingEval = std::clamp(fallingEval, 0.5, 1.5); + + // If the bestMove is stable over several iterations, reduce time accordingly + timeReduction = lastBestMoveDepth + 8 < completedDepth ? 1.57 : 0.65; + double reduction = + (1.4 + mainThread->previousTimeReduction) / (2.08 * timeReduction); + double bestMoveInstability = 1 + 1.8 * totBestMoveChanges / Threads.size(); + + double totalTime = Time.optimum() * fallingEval * reduction * bestMoveInstability; + + // Cap used time in case of a single legal move for a better viewer experience in tournaments + // yielding correct scores and sufficiently fast moves. + if (rootMoves.size() == 1) totalTime = std::min(500.0, totalTime); + + // Stop the search if we have exceeded the totalTime + if (Time.elapsed() > totalTime) { + // If we are allowed to ponder do not stop the search now but + // keep pondering until the GUI sends "ponderhit" or "stop". + if (mainThread->ponder) + mainThread->stopOnPonderhit = true; + else + Threads.stop = true; + } else if (!mainThread->ponder && Time.elapsed() > totalTime * 0.50) + Threads.increaseDepth = false; + else + Threads.increaseDepth = true; + } + + mainThread->iterValue[iterIdx] = bestValue; + iterIdx = (iterIdx + 1) & 3; } - // Partial workaround for the graph history interaction problem - // For high rule50 counts don't produce transposition table cutoffs. - if (pos.rule50_count() < 90) - return ttValue; + if (!mainThread) return; + + mainThread->previousTimeReduction = timeReduction; + + // If the skill level is enabled, swap the best PV line with the sub-optimal one + if (skill.enabled()) + std::swap(rootMoves[0], *std::find(rootMoves.begin(), rootMoves.end(), + skill.best ? skill.best : skill.pick_best(multiPV))); } - // Step 5. Tablebases probe - if (!rootNode && !excludedMove && TB::Cardinality) - { - int piecesCount = pos.count(); - if ( piecesCount <= TB::Cardinality - && (piecesCount < TB::Cardinality || depth >= TB::ProbeDepth) - && pos.rule50_count() == 0 - && !pos.can_castle(ANY_CASTLING)) - { - TB::ProbeState err; - TB::WDLScore wdl = Tablebases::probe_wdl(pos, &err); + namespace { - // Force check of time on the next occasion - if (thisThread == Threads.main()) - static_cast(thisThread)->callsCnt = 0; + // search<>() is the main search function for both PV and non-PV nodes - if (err != TB::ProbeState::FAIL) - { - thisThread->tbHits.fetch_add(1, std::memory_order_relaxed); + template + Value search(Position& pos, Stack* ss, Value alpha, Value beta, Depth depth, bool cutNode) { - int drawScore = TB::UseRule50 ? 1 : 0; + constexpr bool PvNode = nodeType != NonPV; + constexpr bool rootNode = nodeType == Root; - // use the range VALUE_MATE_IN_MAX_PLY to VALUE_TB_WIN_IN_MAX_PLY to score - value = wdl < -drawScore ? VALUE_MATED_IN_MAX_PLY + ss->ply + 1 - : wdl > drawScore ? VALUE_MATE_IN_MAX_PLY - ss->ply - 1 - : VALUE_DRAW + 2 * wdl * drawScore; + // Check if we have an upcoming move that draws by repetition, or + // if the opponent had an alternative move earlier to this position. + if (!rootNode && alpha < VALUE_DRAW && pos.has_game_cycle(ss->ply)) { + alpha = value_draw(pos.this_thread()); + if (alpha >= beta) return alpha; + } - Bound b = wdl < -drawScore ? BOUND_UPPER - : wdl > drawScore ? BOUND_LOWER : BOUND_EXACT; + // Dive into quiescence search when the depth reaches zero + if (depth <= 0) return qsearch < PvNode ? PV : NonPV > (pos, ss, alpha, beta); + + assert(-VALUE_INFINITE <= alpha && alpha < beta && beta <= VALUE_INFINITE); + assert(PvNode || (alpha == beta - 1)); + assert(0 < depth && depth < MAX_PLY); + assert(!(PvNode && cutNode)); + + Move pv[MAX_PLY + 1], capturesSearched[32], quietsSearched[64]; + StateInfo st; + ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize); + + TTEntry* tte; + Key posKey; + Move ttMove, move, excludedMove, bestMove; + Depth extension, newDepth; + Value bestValue, value, ttValue, eval, maxValue, probCutBeta; + bool givesCheck, improving, priorCapture, singularQuietLMR; + bool capture, moveCountPruning, ttCapture; + Piece movedPiece; + int moveCount, captureCount, quietCount; + + // Step 1. Initialize node + Thread* thisThread = pos.this_thread(); + ss->inCheck = pos.checkers(); + priorCapture = pos.captured_piece(); + Color us = pos.side_to_move(); + moveCount = captureCount = quietCount = ss->moveCount = 0; + bestValue = -VALUE_INFINITE; + maxValue = VALUE_INFINITE; + + // Check for the available remaining time + if (thisThread == Threads.main()) static_cast(thisThread)->check_time(); + + // Used to send selDepth info to GUI (selDepth counts from 1, ply from 0) + if (PvNode && thisThread->selDepth < ss->ply + 1) thisThread->selDepth = ss->ply + 1; + + if (!rootNode) { + // Step 2. Check for aborted search and immediate draw + if (Threads.stop.load(std::memory_order_relaxed) || pos.is_draw(ss->ply) || + ss->ply >= MAX_PLY) + return (ss->ply >= MAX_PLY && !ss->inCheck) ? evaluate(pos) : + value_draw(pos.this_thread()); + + // Step 3. Mate distance pruning. Even if we mate at the next move our score + // would be at best mate_in(ss->ply+1), but if alpha is already bigger because + // a shorter mate was found upward in the tree then there is no need to search + // because we will never beat the current alpha. Same logic but with reversed + // signs apply also in the opposite condition of being mated instead of giving + // mate. In this case, return a fail-high score. + alpha = std::max(mated_in(ss->ply), alpha); + beta = std::min(mate_in(ss->ply + 1), beta); + if (alpha >= beta) return alpha; + } else + thisThread->rootDelta = beta - alpha; + + assert(0 <= ss->ply && ss->ply < MAX_PLY); + + (ss + 1)->excludedMove = bestMove = MOVE_NONE; + (ss + 2)->killers[0] = (ss + 2)->killers[1] = MOVE_NONE; + (ss + 2)->cutoffCnt = 0; + ss->doubleExtensions = (ss - 1)->doubleExtensions; + Square prevSq = is_ok((ss - 1)->currentMove) ? to_sq((ss - 1)->currentMove) : SQ_NONE; + ss->statScore = 0; + + // Step 4. Transposition table lookup. + excludedMove = ss->excludedMove; + posKey = pos.key(); + tte = TT.probe(posKey, ss->ttHit); + ttValue = + ss->ttHit ? value_from_tt(tte->value(), ss->ply, pos.rule50_count()) : VALUE_NONE; + ttMove = rootNode ? thisThread->rootMoves[thisThread->pvIdx].pv[0] : + ss->ttHit ? tte->move() : + MOVE_NONE; + ttCapture = ttMove && pos.capture_stage(ttMove); + + // At this point, if excluded, skip straight to step 6, static eval. However, + // to save indentation, we list the condition in all code between here and there. + if (!excludedMove) ss->ttPv = PvNode || (ss->ttHit && tte->is_pv()); + + // At non-PV nodes we check for an early TT cutoff + if (!PvNode && !excludedMove && tte->depth() > depth && + ttValue != VALUE_NONE // Possible in case of TT access race or if !ttHit + && (tte->bound() & (ttValue >= beta ? BOUND_LOWER : BOUND_UPPER))) { + // If ttMove is quiet, update move sorting heuristics on TT hit (~2 Elo) + if (ttMove) { + if (ttValue >= beta) { + // Bonus for a quiet ttMove that fails high (~2 Elo) + if (!ttCapture) update_quiet_stats(pos, ss, ttMove, stat_bonus(depth)); + + // Extra penalty for early quiet moves of the previous ply (~0 Elo on STC, ~2 Elo on LTC) + if (prevSq != SQ_NONE && (ss - 1)->moveCount <= 2 && !priorCapture) + update_continuation_histories(ss - 1, pos.piece_on(prevSq), prevSq, + -stat_bonus(depth + 1)); + } + // Penalty for a quiet ttMove that fails low (~1 Elo) + else if (!ttCapture) { + int penalty = -stat_bonus(depth); + thisThread->mainHistory[us][from_to(ttMove)] << penalty; + update_continuation_histories(ss, pos.moved_piece(ttMove), to_sq(ttMove), + penalty); + } + } - if ( b == BOUND_EXACT - || (b == BOUND_LOWER ? value >= beta : value <= alpha)) - { - tte->save(posKey, value_to_tt(value, ss->ply), ss->ttPv, b, - std::min(MAX_PLY - 1, depth + 6), - MOVE_NONE, VALUE_NONE); + // Partial workaround for the graph history interaction problem + // For high rule50 counts don't produce transposition table cutoffs. + if (pos.rule50_count() < 90) return ttValue; + } - return value; + // Step 5. Tablebases probe + if (!rootNode && !excludedMove && TB::Cardinality) { + int piecesCount = pos.count(); + + if (piecesCount <= TB::Cardinality && + (piecesCount < TB::Cardinality || depth >= TB::ProbeDepth) && + pos.rule50_count() == 0 && !pos.can_castle(ANY_CASTLING)) { + TB::ProbeState err; + TB::WDLScore wdl = Tablebases::probe_wdl(pos, &err); + + // Force check of time on the next occasion + if (thisThread == Threads.main()) + static_cast(thisThread)->callsCnt = 0; + + if (err != TB::ProbeState::FAIL) { + thisThread->tbHits.fetch_add(1, std::memory_order_relaxed); + + int drawScore = TB::UseRule50 ? 1 : 0; + + // use the range VALUE_MATE_IN_MAX_PLY to VALUE_TB_WIN_IN_MAX_PLY to score + value = wdl < -drawScore ? VALUE_MATED_IN_MAX_PLY + ss->ply + 1 : + wdl > drawScore ? VALUE_MATE_IN_MAX_PLY - ss->ply - 1 : + VALUE_DRAW + 2 * wdl * drawScore; + + Bound b = wdl < -drawScore ? BOUND_UPPER : + wdl > drawScore ? BOUND_LOWER : + BOUND_EXACT; + + if (b == BOUND_EXACT || + (b == BOUND_LOWER ? value >= beta : value <= alpha)) { + tte->save(posKey, value_to_tt(value, ss->ply), ss->ttPv, b, + std::min(MAX_PLY - 1, depth + 6), MOVE_NONE, VALUE_NONE); + + return value; + } + + if (PvNode) { + if (b == BOUND_LOWER) + bestValue = value, alpha = std::max(alpha, bestValue); + else + maxValue = value; + } + } } + } - if (PvNode) - { - if (b == BOUND_LOWER) - bestValue = value, alpha = std::max(alpha, bestValue); - else - maxValue = value; + CapturePieceToHistory& captureHistory = thisThread->captureHistory; + + // Step 6. Static evaluation of the position + if (ss->inCheck) { + // Skip early pruning when in check + ss->staticEval = eval = VALUE_NONE; + improving = false; + goto moves_loop; + } else if (excludedMove) { + // Providing the hint that this node's accumulator will be used often brings significant Elo gain (13 Elo) + Eval::NNUE::hint_common_parent_position(pos); + eval = ss->staticEval; + } else if (ss->ttHit) { + // Never assume anything about values stored in TT + ss->staticEval = eval = tte->eval(); + if (eval == VALUE_NONE) + ss->staticEval = eval = evaluate(pos); + else if (PvNode) + Eval::NNUE::hint_common_parent_position(pos); + + // ttValue can be used as a better position evaluation (~7 Elo) + if (ttValue != VALUE_NONE && + (tte->bound() & (ttValue > eval ? BOUND_LOWER : BOUND_UPPER))) + eval = ttValue; + } else { + ss->staticEval = eval = evaluate(pos); + // Save static evaluation into the transposition table + tte->save(posKey, VALUE_NONE, ss->ttPv, BOUND_NONE, DEPTH_NONE, MOVE_NONE, eval); + } + + // Use static evaluation difference to improve quiet move ordering (~4 Elo) + if (is_ok((ss - 1)->currentMove) && !(ss - 1)->inCheck && !priorCapture) { + int bonus = + std::clamp(-18 * int((ss - 1)->staticEval + ss->staticEval), -1817, 1817); + thisThread->mainHistory[~us][from_to((ss - 1)->currentMove)] << bonus; + } + + // Set up the improving flag, which is true if current static evaluation is + // bigger than the previous static evaluation at our turn (if we were in + // check at our previous move we look at static evaluation at move prior to it + // and if we were in check at move prior to it flag is set to true) and is + // false otherwise. The improving flag is used in various pruning heuristics. + improving = (ss - 2)->staticEval != VALUE_NONE ? ss->staticEval > (ss - 2)->staticEval : + (ss - 4)->staticEval != VALUE_NONE ? ss->staticEval > (ss - 4)->staticEval : + true; + + // Step 7. Razoring (~1 Elo). + // If eval is really low check with qsearch if it can exceed alpha, if it can't, + // return a fail low. + if (eval < alpha - 456 - 252 * depth * depth) { + value = qsearch(pos, ss, alpha - 1, alpha); + if (value < alpha) return value; + } + + // Step 8. Futility pruning: child node (~40 Elo). + // The depth condition is important for mate finding. + if (!ss->ttPv && depth < 9 && + eval - futility_margin(depth, cutNode && !ss->ttHit, improving) - + (ss - 1)->statScore / 306 >= + beta && + eval >= beta && + eval < 24923) // larger than VALUE_KNOWN_WIN, but smaller than TB wins + return eval; + + // Step 9. Null move search with verification search (~35 Elo) + if (!PvNode && (ss - 1)->currentMove != MOVE_NULL && (ss - 1)->statScore < 17329 && + eval >= beta && eval >= ss->staticEval && + ss->staticEval >= beta - 21 * depth + 258 && !excludedMove && + pos.non_pawn_material(us) && ss->ply >= thisThread->nmpMinPly && + beta > VALUE_TB_LOSS_IN_MAX_PLY) { + assert(eval - beta >= 0); + + // Null move dynamic reduction based on depth and eval + Depth R = std::min(int(eval - beta) / 173, 6) + depth / 3 + 4; + + ss->currentMove = MOVE_NULL; + ss->continuationHistory = &thisThread->continuationHistory[0][0][NO_PIECE][0]; + + pos.do_null_move(st); + + Value nullValue = + -search(pos, ss + 1, -beta, -beta + 1, depth - R, !cutNode); + + pos.undo_null_move(); + + if (nullValue >= beta) { + // Do not return unproven mate or TB scores + nullValue = std::min(nullValue, VALUE_TB_WIN_IN_MAX_PLY - 1); + + if (thisThread->nmpMinPly || depth < 14) return nullValue; + + assert(!thisThread->nmpMinPly); // Recursive verification is not allowed + + // Do verification search at high depths, with null move pruning disabled + // until ply exceeds nmpMinPly. + thisThread->nmpMinPly = ss->ply + 3 * (depth - R) / 4; + + Value v = search(pos, ss, beta - 1, beta, depth - R, false); + + thisThread->nmpMinPly = 0; + + if (v >= beta) return nullValue; } } - } - } - CapturePieceToHistory& captureHistory = thisThread->captureHistory; + // Step 10. If the position doesn't have a ttMove, decrease depth by 2 + // (or by 4 if the TT entry for the current position was hit and the stored depth is greater than or equal to the current depth). + // Use qsearch if depth is equal or below zero (~9 Elo) + if (PvNode && !ttMove) depth -= 2 + 2 * (ss->ttHit && tte->depth() >= depth); - // Step 6. Static evaluation of the position - if (ss->inCheck) - { - // Skip early pruning when in check - ss->staticEval = eval = VALUE_NONE; - improving = false; - goto moves_loop; - } - else if (excludedMove) - { - // Providing the hint that this node's accumulator will be used often brings significant Elo gain (13 Elo) - Eval::NNUE::hint_common_parent_position(pos); - eval = ss->staticEval; - } - else if (ss->ttHit) - { - // Never assume anything about values stored in TT - ss->staticEval = eval = tte->eval(); - if (eval == VALUE_NONE) - ss->staticEval = eval = evaluate(pos); - else if (PvNode) - Eval::NNUE::hint_common_parent_position(pos); - - // ttValue can be used as a better position evaluation (~7 Elo) - if ( ttValue != VALUE_NONE - && (tte->bound() & (ttValue > eval ? BOUND_LOWER : BOUND_UPPER))) - eval = ttValue; - } - else - { - ss->staticEval = eval = evaluate(pos); - // Save static evaluation into the transposition table - tte->save(posKey, VALUE_NONE, ss->ttPv, BOUND_NONE, DEPTH_NONE, MOVE_NONE, eval); - } + if (depth <= 0) return qsearch(pos, ss, alpha, beta); - // Use static evaluation difference to improve quiet move ordering (~4 Elo) - if (is_ok((ss-1)->currentMove) && !(ss-1)->inCheck && !priorCapture) - { - int bonus = std::clamp(-18 * int((ss-1)->staticEval + ss->staticEval), -1817, 1817); - thisThread->mainHistory[~us][from_to((ss-1)->currentMove)] << bonus; - } + if (cutNode && depth >= 8 && !ttMove) depth -= 2; - // Set up the improving flag, which is true if current static evaluation is - // bigger than the previous static evaluation at our turn (if we were in - // check at our previous move we look at static evaluation at move prior to it - // and if we were in check at move prior to it flag is set to true) and is - // false otherwise. The improving flag is used in various pruning heuristics. - improving = (ss-2)->staticEval != VALUE_NONE ? ss->staticEval > (ss-2)->staticEval - : (ss-4)->staticEval != VALUE_NONE ? ss->staticEval > (ss-4)->staticEval - : true; - - // Step 7. Razoring (~1 Elo). - // If eval is really low check with qsearch if it can exceed alpha, if it can't, - // return a fail low. - if (eval < alpha - 456 - 252 * depth * depth) - { - value = qsearch(pos, ss, alpha - 1, alpha); - if (value < alpha) - return value; - } + probCutBeta = beta + 168 - 61 * improving; - // Step 8. Futility pruning: child node (~40 Elo). - // The depth condition is important for mate finding. - if ( !ss->ttPv - && depth < 9 - && eval - futility_margin(depth, cutNode && !ss->ttHit, improving) - (ss-1)->statScore / 306 >= beta - && eval >= beta - && eval < 24923) // larger than VALUE_KNOWN_WIN, but smaller than TB wins - return eval; + // Step 11. ProbCut (~10 Elo) + // If we have a good enough capture (or queen promotion) and a reduced search returns a value + // much above beta, we can (almost) safely prune the previous move. + if ( + !PvNode && depth > 3 && + abs(beta) < VALUE_TB_WIN_IN_MAX_PLY + // If value from transposition table is lower than probCutBeta, don't attempt probCut + // there and in further interactions with transposition table cutoff depth is set to depth - 3 + // because probCut search has depth set to depth - 4 but we also do a move before it + // So effective depth is equal to depth - 3 + && !(tte->depth() >= depth - 3 && ttValue != VALUE_NONE && ttValue < probCutBeta)) { + assert(probCutBeta < VALUE_INFINITE); - // Step 9. Null move search with verification search (~35 Elo) - if ( !PvNode - && (ss-1)->currentMove != MOVE_NULL - && (ss-1)->statScore < 17329 - && eval >= beta - && eval >= ss->staticEval - && ss->staticEval >= beta - 21 * depth + 258 - && !excludedMove - && pos.non_pawn_material(us) - && ss->ply >= thisThread->nmpMinPly - && beta > VALUE_TB_LOSS_IN_MAX_PLY) - { - assert(eval - beta >= 0); + MovePicker mp(pos, ttMove, probCutBeta - ss->staticEval, &captureHistory); - // Null move dynamic reduction based on depth and eval - Depth R = std::min(int(eval - beta) / 173, 6) + depth / 3 + 4; + while ((move = mp.next_move()) != MOVE_NONE) + if (move != excludedMove && pos.legal(move)) { + assert(pos.capture_stage(move)); - ss->currentMove = MOVE_NULL; - ss->continuationHistory = &thisThread->continuationHistory[0][0][NO_PIECE][0]; + ss->currentMove = move; + ss->continuationHistory = + &thisThread->continuationHistory[ss->inCheck][true][pos.moved_piece(move)] + [to_sq(move)]; - pos.do_null_move(st); + pos.do_move(move, st); - Value nullValue = -search(pos, ss+1, -beta, -beta+1, depth-R, !cutNode); + // Perform a preliminary qsearch to verify that the move holds + value = -qsearch(pos, ss + 1, -probCutBeta, -probCutBeta + 1); - pos.undo_null_move(); + // If the qsearch held, perform the regular search + if (value >= probCutBeta) + value = -search(pos, ss + 1, -probCutBeta, -probCutBeta + 1, + depth - 4, !cutNode); - if (nullValue >= beta) - { - // Do not return unproven mate or TB scores - nullValue = std::min(nullValue, VALUE_TB_WIN_IN_MAX_PLY-1); + pos.undo_move(move); - if (thisThread->nmpMinPly || depth < 14) - return nullValue; + if (value >= probCutBeta) { + // Save ProbCut data into transposition table + tte->save(posKey, value_to_tt(value, ss->ply), ss->ttPv, BOUND_LOWER, + depth - 3, move, ss->staticEval); + return value; + } + } - assert(!thisThread->nmpMinPly); // Recursive verification is not allowed + Eval::NNUE::hint_common_parent_position(pos); + } - // Do verification search at high depths, with null move pruning disabled - // until ply exceeds nmpMinPly. - thisThread->nmpMinPly = ss->ply + 3 * (depth-R) / 4; +moves_loop: // When in check, search starts here + + // Step 12. A small Probcut idea, when we are in check (~4 Elo) + probCutBeta = beta + 413; + if (ss->inCheck && !PvNode && ttCapture && (tte->bound() & BOUND_LOWER) && + tte->depth() >= depth - 4 && ttValue >= probCutBeta && + abs(ttValue) <= VALUE_KNOWN_WIN && abs(beta) <= VALUE_KNOWN_WIN) + return probCutBeta; + + const PieceToHistory* contHist[] = {(ss - 1)->continuationHistory, + (ss - 2)->continuationHistory, + nullptr, + (ss - 4)->continuationHistory, + nullptr, + (ss - 6)->continuationHistory}; + + Move countermove = prevSq != SQ_NONE ? + thisThread->counterMoves[pos.piece_on(prevSq)][prevSq] : + MOVE_NONE; + + MovePicker mp(pos, ttMove, depth, &thisThread->mainHistory, &captureHistory, contHist, + countermove, ss->killers); + + value = bestValue; + moveCountPruning = singularQuietLMR = false; + + // Indicate PvNodes that will probably fail low if the node was searched + // at a depth equal to or greater than the current depth, and the result + // of this search was a fail low. + bool likelyFailLow = + PvNode && ttMove && (tte->bound() & BOUND_UPPER) && tte->depth() >= depth; + + // Step 13. Loop through all pseudo-legal moves until no moves remain + // or a beta cutoff occurs. + while ((move = mp.next_move(moveCountPruning)) != MOVE_NONE) { + assert(is_ok(move)); + + if (move == excludedMove) continue; + + // At root obey the "searchmoves" option and skip moves not listed in Root + // Move List. As a consequence, any illegal move is also skipped. In MultiPV + // mode we also skip PV moves that have been already searched and those + // of lower "TB rank" if we are in a TB root position. + if (rootNode && + !std::count(thisThread->rootMoves.begin() + thisThread->pvIdx, + thisThread->rootMoves.begin() + thisThread->pvLast, move)) + continue; - Value v = search(pos, ss, beta-1, beta, depth-R, false); + // Check for legality + if (!rootNode && !pos.legal(move)) continue; - thisThread->nmpMinPly = 0; + ss->moveCount = ++moveCount; - if (v >= beta) - return nullValue; - } - } + if (rootNode && thisThread == Threads.main() && Time.elapsed() > 3000) + sync_cout << "info depth " << depth << " currmove " + << UCI::move(move, pos.is_chess960()) << " currmovenumber " + << moveCount + thisThread->pvIdx << sync_endl; + if (PvNode) (ss + 1)->pv = nullptr; - // Step 10. If the position doesn't have a ttMove, decrease depth by 2 - // (or by 4 if the TT entry for the current position was hit and the stored depth is greater than or equal to the current depth). - // Use qsearch if depth is equal or below zero (~9 Elo) - if ( PvNode - && !ttMove) - depth -= 2 + 2 * (ss->ttHit && tte->depth() >= depth); - - if (depth <= 0) - return qsearch(pos, ss, alpha, beta); - - if ( cutNode - && depth >= 8 - && !ttMove) - depth -= 2; - - probCutBeta = beta + 168 - 61 * improving; - - // Step 11. ProbCut (~10 Elo) - // If we have a good enough capture (or queen promotion) and a reduced search returns a value - // much above beta, we can (almost) safely prune the previous move. - if ( !PvNode - && depth > 3 - && abs(beta) < VALUE_TB_WIN_IN_MAX_PLY - // If value from transposition table is lower than probCutBeta, don't attempt probCut - // there and in further interactions with transposition table cutoff depth is set to depth - 3 - // because probCut search has depth set to depth - 4 but we also do a move before it - // So effective depth is equal to depth - 3 - && !( tte->depth() >= depth - 3 - && ttValue != VALUE_NONE - && ttValue < probCutBeta)) - { - assert(probCutBeta < VALUE_INFINITE); - - MovePicker mp(pos, ttMove, probCutBeta - ss->staticEval, &captureHistory); - - while ((move = mp.next_move()) != MOVE_NONE) - if (move != excludedMove && pos.legal(move)) - { - assert(pos.capture_stage(move)); + extension = 0; + capture = pos.capture_stage(move); + movedPiece = pos.moved_piece(move); + givesCheck = pos.gives_check(move); + + // Calculate new depth for this move + newDepth = depth - 1; + + Value delta = beta - alpha; + + Depth r = reduction(improving, depth, moveCount, delta, thisThread->rootDelta); + + // Step 14. Pruning at shallow depth (~120 Elo). Depth conditions are important for mate finding. + if (!rootNode && pos.non_pawn_material(us) && + bestValue > VALUE_TB_LOSS_IN_MAX_PLY) { + // Skip quiet moves if movecount exceeds our FutilityMoveCount threshold (~8 Elo) + moveCountPruning = moveCount >= futility_move_count(improving, depth); + + // Reduced depth of the next LMR search + int lmrDepth = newDepth - r; + + if (capture || givesCheck) { + // Futility pruning for captures (~2 Elo) + if (!givesCheck && lmrDepth < 7 && !ss->inCheck && + ss->staticEval + 197 + 248 * lmrDepth + + PieceValue[pos.piece_on(to_sq(move))] + + captureHistory[movedPiece][to_sq(move)] + [type_of(pos.piece_on(to_sq(move)))] / + 7 < + alpha) + continue; + + // SEE based pruning for captures and checks (~11 Elo) + if (!pos.see_ge(move, Value(-205) * depth)) continue; + } else { + int history = (*contHist[0])[movedPiece][to_sq(move)] + + (*contHist[1])[movedPiece][to_sq(move)] + + (*contHist[3])[movedPiece][to_sq(move)]; + + // Continuation history based pruning (~2 Elo) + if (lmrDepth < 6 && history < -3832 * depth) continue; + + history += 2 * thisThread->mainHistory[us][from_to(move)]; + + lmrDepth += history / 7011; + lmrDepth = std::max(lmrDepth, -2); + + // Futility pruning: parent node (~13 Elo) + if (!ss->inCheck && lmrDepth < 12 && + ss->staticEval + 112 + 138 * lmrDepth <= alpha) + continue; + + lmrDepth = std::max(lmrDepth, 0); + + // Prune moves with negative SEE (~4 Elo) + if (!pos.see_ge(move, Value(-31 * lmrDepth * lmrDepth))) continue; + } + } + // Step 15. Extensions (~100 Elo) + // We take care to not overdo to avoid search getting stuck. + if (ss->ply < thisThread->rootDepth * 2) { + // Singular extension search (~94 Elo). If all moves but one fail low on a + // search of (alpha-s, beta-s), and just one fails high on (alpha, beta), + // then that move is singular and should be extended. To verify this we do + // a reduced search on all the other moves but the ttMove and if the result + // is lower than ttValue minus a margin, then we will extend the ttMove. Note + // that depth margin and singularBeta margin are known for having non-linear + // scaling. Their values are optimized to time controls of 180+1.8 and longer + // so changing them requires tests at this type of time controls. + if (!rootNode && + depth >= + 4 - (thisThread->completedDepth > 22) + 2 * (PvNode && tte->is_pv()) && + move == ttMove && + !excludedMove // Avoid recursive singular search + /* && ttValue != VALUE_NONE Already implicit in the next condition */ + && abs(ttValue) < VALUE_KNOWN_WIN && (tte->bound() & BOUND_LOWER) && + tte->depth() >= depth - 3) { + Value singularBeta = + ttValue - (82 + 65 * (ss->ttPv && !PvNode)) * depth / 64; + Depth singularDepth = (depth - 1) / 2; + + ss->excludedMove = move; + value = search(pos, ss, singularBeta - 1, singularBeta, + singularDepth, cutNode); + ss->excludedMove = MOVE_NONE; + + if (value < singularBeta) { + extension = 1; + singularQuietLMR = !ttCapture; + + // Avoid search explosion by limiting the number of double extensions + if (!PvNode && value < singularBeta - 21 && + ss->doubleExtensions <= 11) { + extension = 2; + depth += depth < 13; + } + } + + // Multi-cut pruning + // Our ttMove is assumed to fail high, and now we failed high also on a + // reduced search without the ttMove. So we assume this expected cut-node + // is not singular, that multiple moves fail high, and we can prune the + // whole subtree by returning a softbound. + else if (singularBeta >= beta) + return singularBeta; + + // If the eval of ttMove is greater than beta, we reduce it (negative extension) (~7 Elo) + else if (ttValue >= beta) + extension = -2 - !PvNode; + + // If we are on a cutNode, reduce it based on depth (negative extension) (~1 Elo) + else if (cutNode) + extension = depth < 17 ? -3 : -1; + + // If the eval of ttMove is less than value, we reduce it (negative extension) (~1 Elo) + else if (ttValue <= value) + extension = -1; + } + + // Check extensions (~1 Elo) + else if (givesCheck && depth > 9) + extension = 1; + + // Quiet ttMove extensions (~1 Elo) + else if (PvNode && move == ttMove && move == ss->killers[0] && + (*contHist[0])[movedPiece][to_sq(move)] >= 5168) + extension = 1; + } + + // Add extension to new depth + newDepth += extension; + ss->doubleExtensions = (ss - 1)->doubleExtensions + (extension == 2); + + // Speculative prefetch as early as possible + prefetch(TT.first_entry(pos.key_after(move))); + + // Update the current move (this must be done after singular extension search) ss->currentMove = move; - ss->continuationHistory = &thisThread->continuationHistory[ss->inCheck] - [true] - [pos.moved_piece(move)] - [to_sq(move)]; + ss->continuationHistory = + &thisThread->continuationHistory[ss->inCheck][capture][movedPiece][to_sq(move)]; + + // Step 16. Make the move + pos.do_move(move, st, givesCheck); + + // Decrease reduction if position is or has been on the PV and not likely to fail low. (~3 Elo) + // Decrease further on cutNodes. (~1 Elo) + if (ss->ttPv && !likelyFailLow) r -= cutNode && tte->depth() >= depth ? 3 : 2; + + // Decrease reduction if opponent's move count is high (~1 Elo) + if ((ss - 1)->moveCount > 8) r--; + + // Increase reduction for cut nodes (~3 Elo) + if (cutNode) r += 2; + + // Increase reduction if ttMove is a capture (~3 Elo) + if (ttCapture) r++; + + // Decrease reduction for PvNodes (~2 Elo) + if (PvNode) r--; + + // Decrease reduction if ttMove has been singularly extended (~1 Elo) + if (singularQuietLMR) r--; + + // Increase reduction on repetition (~1 Elo) + if (move == (ss - 4)->currentMove && pos.has_repeated()) r += 2; + + // Increase reduction if next ply has a lot of fail high (~5 Elo) + if ((ss + 1)->cutoffCnt > 3) r++; + + // Decrease reduction for first generated move (ttMove) + else if (move == ttMove) + r--; + + ss->statScore = 2 * thisThread->mainHistory[us][from_to(move)] + + (*contHist[0])[movedPiece][to_sq(move)] + + (*contHist[1])[movedPiece][to_sq(move)] + + (*contHist[3])[movedPiece][to_sq(move)] - 4006; - pos.do_move(move, st); + // Decrease/increase reduction for moves with a good/bad history (~25 Elo) + r -= ss->statScore / (11124 + 4740 * (depth > 5 && depth < 22)); - // Perform a preliminary qsearch to verify that the move holds - value = -qsearch(pos, ss+1, -probCutBeta, -probCutBeta+1); + // Step 17. Late moves reduction / extension (LMR, ~117 Elo) + // We use various heuristics for the sons of a node after the first son has + // been searched. In general, we would like to reduce them, but there are many + // cases where we extend a son if it has good chances to be "interesting". + if (depth >= 2 && moveCount > 1 + (PvNode && ss->ply <= 1) && + (!ss->ttPv || !capture || (cutNode && (ss - 1)->moveCount > 1))) { + // In general we want to cap the LMR depth search at newDepth, but when + // reduction is negative, we allow this move a limited search extension + // beyond the first move depth. This may lead to hidden double extensions. + Depth d = std::clamp(newDepth - r, 1, newDepth + 1); - // If the qsearch held, perform the regular search - if (value >= probCutBeta) - value = -search(pos, ss+1, -probCutBeta, -probCutBeta+1, depth - 4, !cutNode); + value = -search(pos, ss + 1, -(alpha + 1), -alpha, d, true); + // Do a full-depth search when reduced LMR search fails high + if (value > alpha && d < newDepth) { + // Adjust full-depth search based on LMR results - if the result + // was good enough search deeper, if it was bad enough search shallower + const bool doDeeperSearch = value > (bestValue + 64 + 11 * (newDepth - d)); + const bool doEvenDeeperSearch = + value > alpha + 711 && ss->doubleExtensions <= 6; + const bool doShallowerSearch = value < bestValue + newDepth; + + ss->doubleExtensions = ss->doubleExtensions + doEvenDeeperSearch; + + newDepth += doDeeperSearch - doShallowerSearch + doEvenDeeperSearch; + + if (newDepth > d) + value = + -search(pos, ss + 1, -(alpha + 1), -alpha, newDepth, !cutNode); + + int bonus = value <= alpha ? -stat_bonus(newDepth) : + value >= beta ? stat_bonus(newDepth) : + 0; + + update_continuation_histories(ss, movedPiece, to_sq(move), bonus); + } + } + + // Step 18. Full-depth search when LMR is skipped. If expected reduction is high, reduce its depth by 1. + else if (!PvNode || moveCount > 1) { + // Increase reduction for cut nodes and not ttMove (~1 Elo) + if (!ttMove && cutNode) r += 2; + + value = -search(pos, ss + 1, -(alpha + 1), -alpha, newDepth - (r > 3), + !cutNode); + } + + // For PV nodes only, do a full PV search on the first move or after a fail high, + // otherwise let the parent node fail low with value <= alpha and try another move. + if (PvNode && (moveCount == 1 || value > alpha)) { + (ss + 1)->pv = pv; + (ss + 1)->pv[0] = MOVE_NONE; + + value = -search(pos, ss + 1, -beta, -alpha, newDepth, false); + } + + // Step 19. Undo move pos.undo_move(move); - if (value >= probCutBeta) - { - // Save ProbCut data into transposition table - tte->save(posKey, value_to_tt(value, ss->ply), ss->ttPv, BOUND_LOWER, depth - 3, move, ss->staticEval); - return value; + assert(value > -VALUE_INFINITE && value < VALUE_INFINITE); + + // Step 20. Check for a new best move + // Finished searching the move. If a stop occurred, the return value of + // the search cannot be trusted, and we return immediately without + // updating best move, PV and TT. + if (Threads.stop.load(std::memory_order_relaxed)) return VALUE_ZERO; + + if (rootNode) { + RootMove& rm = + *std::find(thisThread->rootMoves.begin(), thisThread->rootMoves.end(), move); + + rm.averageScore = rm.averageScore != -VALUE_INFINITE ? + (2 * value + rm.averageScore) / 3 : + value; + + // PV move or new best move? + if (moveCount == 1 || value > alpha) { + rm.score = rm.uciScore = value; + rm.selDepth = thisThread->selDepth; + rm.scoreLowerbound = rm.scoreUpperbound = false; + + if (value >= beta) { + rm.scoreLowerbound = true; + rm.uciScore = beta; + } else if (value <= alpha) { + rm.scoreUpperbound = true; + rm.uciScore = alpha; + } + + rm.pv.resize(1); + + assert((ss + 1)->pv); + + for (Move* m = (ss + 1)->pv; *m != MOVE_NONE; ++m) rm.pv.push_back(*m); + + // We record how often the best move has been changed in each iteration. + // This information is used for time management. In MultiPV mode, + // we must take care to only do this for the first PV line. + if (moveCount > 1 && !thisThread->pvIdx) ++thisThread->bestMoveChanges; + } else + // All other moves but the PV, are set to the lowest value: this + // is not a problem when sorting because the sort is stable and the + // move position in the list is preserved - just the PV is pushed up. + rm.score = -VALUE_INFINITE; } - } - Eval::NNUE::hint_common_parent_position(pos); - } + if (value > bestValue) { + bestValue = value; + + if (value > alpha) { + bestMove = move; + + if (PvNode && !rootNode) // Update pv even in fail-high case + update_pv(ss->pv, move, (ss + 1)->pv); + + if (value >= beta) { + ss->cutoffCnt += 1 + !ttMove; + assert(value >= beta); // Fail high + break; + } else { + // Reduce other moves if we have found at least one score improvement (~2 Elo) + if (depth > 2 && depth < 12 && beta < 14362 && value > -12393) + depth -= 2; + + assert(depth > 0); + alpha = value; // Update alpha! Always alpha < beta + } + } + } -moves_loop: // When in check, search starts here - - // Step 12. A small Probcut idea, when we are in check (~4 Elo) - probCutBeta = beta + 413; - if ( ss->inCheck - && !PvNode - && ttCapture - && (tte->bound() & BOUND_LOWER) - && tte->depth() >= depth - 4 - && ttValue >= probCutBeta - && abs(ttValue) <= VALUE_KNOWN_WIN - && abs(beta) <= VALUE_KNOWN_WIN) - return probCutBeta; - - const PieceToHistory* contHist[] = { (ss-1)->continuationHistory, (ss-2)->continuationHistory, - nullptr , (ss-4)->continuationHistory, - nullptr , (ss-6)->continuationHistory }; - - Move countermove = prevSq != SQ_NONE ? thisThread->counterMoves[pos.piece_on(prevSq)][prevSq] : MOVE_NONE; - - MovePicker mp(pos, ttMove, depth, &thisThread->mainHistory, - &captureHistory, - contHist, - countermove, - ss->killers); - - value = bestValue; - moveCountPruning = singularQuietLMR = false; - - // Indicate PvNodes that will probably fail low if the node was searched - // at a depth equal to or greater than the current depth, and the result - // of this search was a fail low. - bool likelyFailLow = PvNode - && ttMove - && (tte->bound() & BOUND_UPPER) - && tte->depth() >= depth; - - // Step 13. Loop through all pseudo-legal moves until no moves remain - // or a beta cutoff occurs. - while ((move = mp.next_move(moveCountPruning)) != MOVE_NONE) - { - assert(is_ok(move)); - - if (move == excludedMove) - continue; - - // At root obey the "searchmoves" option and skip moves not listed in Root - // Move List. As a consequence, any illegal move is also skipped. In MultiPV - // mode we also skip PV moves that have been already searched and those - // of lower "TB rank" if we are in a TB root position. - if (rootNode && !std::count(thisThread->rootMoves.begin() + thisThread->pvIdx, - thisThread->rootMoves.begin() + thisThread->pvLast, move)) - continue; - - // Check for legality - if (!rootNode && !pos.legal(move)) - continue; - - ss->moveCount = ++moveCount; - - if (rootNode && thisThread == Threads.main() && Time.elapsed() > 3000) - sync_cout << "info depth " << depth - << " currmove " << UCI::move(move, pos.is_chess960()) - << " currmovenumber " << moveCount + thisThread->pvIdx << sync_endl; - if (PvNode) - (ss+1)->pv = nullptr; - - extension = 0; - capture = pos.capture_stage(move); - movedPiece = pos.moved_piece(move); - givesCheck = pos.gives_check(move); - - // Calculate new depth for this move - newDepth = depth - 1; - - Value delta = beta - alpha; - - Depth r = reduction(improving, depth, moveCount, delta, thisThread->rootDelta); - - // Step 14. Pruning at shallow depth (~120 Elo). Depth conditions are important for mate finding. - if ( !rootNode - && pos.non_pawn_material(us) - && bestValue > VALUE_TB_LOSS_IN_MAX_PLY) - { - // Skip quiet moves if movecount exceeds our FutilityMoveCount threshold (~8 Elo) - moveCountPruning = moveCount >= futility_move_count(improving, depth); - - // Reduced depth of the next LMR search - int lmrDepth = newDepth - r; - - if ( capture - || givesCheck) - { - // Futility pruning for captures (~2 Elo) - if ( !givesCheck - && lmrDepth < 7 - && !ss->inCheck - && ss->staticEval + 197 + 248 * lmrDepth + PieceValue[pos.piece_on(to_sq(move))] - + captureHistory[movedPiece][to_sq(move)][type_of(pos.piece_on(to_sq(move)))] / 7 < alpha) - continue; - - // SEE based pruning for captures and checks (~11 Elo) - if (!pos.see_ge(move, Value(-205) * depth)) - continue; - } - else - { - int history = (*contHist[0])[movedPiece][to_sq(move)] - + (*contHist[1])[movedPiece][to_sq(move)] - + (*contHist[3])[movedPiece][to_sq(move)]; - - // Continuation history based pruning (~2 Elo) - if ( lmrDepth < 6 - && history < -3832 * depth) - continue; - - history += 2 * thisThread->mainHistory[us][from_to(move)]; - - lmrDepth += history / 7011; - lmrDepth = std::max(lmrDepth, -2); - - // Futility pruning: parent node (~13 Elo) - if ( !ss->inCheck - && lmrDepth < 12 - && ss->staticEval + 112 + 138 * lmrDepth <= alpha) - continue; - - lmrDepth = std::max(lmrDepth, 0); - - // Prune moves with negative SEE (~4 Elo) - if (!pos.see_ge(move, Value(-31 * lmrDepth * lmrDepth))) - continue; - } - } - - // Step 15. Extensions (~100 Elo) - // We take care to not overdo to avoid search getting stuck. - if (ss->ply < thisThread->rootDepth * 2) - { - // Singular extension search (~94 Elo). If all moves but one fail low on a - // search of (alpha-s, beta-s), and just one fails high on (alpha, beta), - // then that move is singular and should be extended. To verify this we do - // a reduced search on all the other moves but the ttMove and if the result - // is lower than ttValue minus a margin, then we will extend the ttMove. Note - // that depth margin and singularBeta margin are known for having non-linear - // scaling. Their values are optimized to time controls of 180+1.8 and longer - // so changing them requires tests at this type of time controls. - if ( !rootNode - && depth >= 4 - (thisThread->completedDepth > 22) + 2 * (PvNode && tte->is_pv()) - && move == ttMove - && !excludedMove // Avoid recursive singular search - /* && ttValue != VALUE_NONE Already implicit in the next condition */ - && abs(ttValue) < VALUE_KNOWN_WIN - && (tte->bound() & BOUND_LOWER) - && tte->depth() >= depth - 3) - { - Value singularBeta = ttValue - (82 + 65 * (ss->ttPv && !PvNode)) * depth / 64; - Depth singularDepth = (depth - 1) / 2; - - ss->excludedMove = move; - value = search(pos, ss, singularBeta - 1, singularBeta, singularDepth, cutNode); - ss->excludedMove = MOVE_NONE; - - if (value < singularBeta) - { - extension = 1; - singularQuietLMR = !ttCapture; - - // Avoid search explosion by limiting the number of double extensions - if ( !PvNode - && value < singularBeta - 21 - && ss->doubleExtensions <= 11) - { - extension = 2; - depth += depth < 13; - } - } - - // Multi-cut pruning - // Our ttMove is assumed to fail high, and now we failed high also on a - // reduced search without the ttMove. So we assume this expected cut-node - // is not singular, that multiple moves fail high, and we can prune the - // whole subtree by returning a softbound. - else if (singularBeta >= beta) - return singularBeta; - - // If the eval of ttMove is greater than beta, we reduce it (negative extension) (~7 Elo) - else if (ttValue >= beta) - extension = -2 - !PvNode; - - // If we are on a cutNode, reduce it based on depth (negative extension) (~1 Elo) - else if (cutNode) - extension = depth < 17 ? -3 : -1; - - // If the eval of ttMove is less than value, we reduce it (negative extension) (~1 Elo) - else if (ttValue <= value) - extension = -1; - } - - // Check extensions (~1 Elo) - else if ( givesCheck - && depth > 9) - extension = 1; - - // Quiet ttMove extensions (~1 Elo) - else if ( PvNode - && move == ttMove - && move == ss->killers[0] - && (*contHist[0])[movedPiece][to_sq(move)] >= 5168) - extension = 1; - } - - // Add extension to new depth - newDepth += extension; - ss->doubleExtensions = (ss-1)->doubleExtensions + (extension == 2); - - // Speculative prefetch as early as possible - prefetch(TT.first_entry(pos.key_after(move))); - - // Update the current move (this must be done after singular extension search) - ss->currentMove = move; - ss->continuationHistory = &thisThread->continuationHistory[ss->inCheck] - [capture] - [movedPiece] - [to_sq(move)]; - - // Step 16. Make the move - pos.do_move(move, st, givesCheck); - - // Decrease reduction if position is or has been on the PV and not likely to fail low. (~3 Elo) - // Decrease further on cutNodes. (~1 Elo) - if ( ss->ttPv - && !likelyFailLow) - r -= cutNode && tte->depth() >= depth ? 3 : 2; - - // Decrease reduction if opponent's move count is high (~1 Elo) - if ((ss-1)->moveCount > 8) - r--; - - // Increase reduction for cut nodes (~3 Elo) - if (cutNode) - r += 2; - - // Increase reduction if ttMove is a capture (~3 Elo) - if (ttCapture) - r++; - - // Decrease reduction for PvNodes (~2 Elo) - if (PvNode) - r--; - - // Decrease reduction if ttMove has been singularly extended (~1 Elo) - if (singularQuietLMR) - r--; - - // Increase reduction on repetition (~1 Elo) - if ( move == (ss-4)->currentMove - && pos.has_repeated()) - r += 2; - - // Increase reduction if next ply has a lot of fail high (~5 Elo) - if ((ss+1)->cutoffCnt > 3) - r++; - - // Decrease reduction for first generated move (ttMove) - else if (move == ttMove) - r--; - - ss->statScore = 2 * thisThread->mainHistory[us][from_to(move)] - + (*contHist[0])[movedPiece][to_sq(move)] - + (*contHist[1])[movedPiece][to_sq(move)] - + (*contHist[3])[movedPiece][to_sq(move)] - - 4006; - - // Decrease/increase reduction for moves with a good/bad history (~25 Elo) - r -= ss->statScore / (11124 + 4740 * (depth > 5 && depth < 22)); - - // Step 17. Late moves reduction / extension (LMR, ~117 Elo) - // We use various heuristics for the sons of a node after the first son has - // been searched. In general, we would like to reduce them, but there are many - // cases where we extend a son if it has good chances to be "interesting". - if ( depth >= 2 - && moveCount > 1 + (PvNode && ss->ply <= 1) - && ( !ss->ttPv - || !capture - || (cutNode && (ss-1)->moveCount > 1))) - { - // In general we want to cap the LMR depth search at newDepth, but when - // reduction is negative, we allow this move a limited search extension - // beyond the first move depth. This may lead to hidden double extensions. - Depth d = std::clamp(newDepth - r, 1, newDepth + 1); - - value = -search(pos, ss+1, -(alpha+1), -alpha, d, true); - - // Do a full-depth search when reduced LMR search fails high - if (value > alpha && d < newDepth) - { - // Adjust full-depth search based on LMR results - if the result - // was good enough search deeper, if it was bad enough search shallower - const bool doDeeperSearch = value > (bestValue + 64 + 11 * (newDepth - d)); - const bool doEvenDeeperSearch = value > alpha + 711 && ss->doubleExtensions <= 6; - const bool doShallowerSearch = value < bestValue + newDepth; - - ss->doubleExtensions = ss->doubleExtensions + doEvenDeeperSearch; - - newDepth += doDeeperSearch - doShallowerSearch + doEvenDeeperSearch; - - if (newDepth > d) - value = -search(pos, ss+1, -(alpha+1), -alpha, newDepth, !cutNode); - - int bonus = value <= alpha ? -stat_bonus(newDepth) - : value >= beta ? stat_bonus(newDepth) - : 0; - - update_continuation_histories(ss, movedPiece, to_sq(move), bonus); - } - } - - // Step 18. Full-depth search when LMR is skipped. If expected reduction is high, reduce its depth by 1. - else if (!PvNode || moveCount > 1) - { - // Increase reduction for cut nodes and not ttMove (~1 Elo) - if (!ttMove && cutNode) - r += 2; - - value = -search(pos, ss+1, -(alpha+1), -alpha, newDepth - (r > 3), !cutNode); - } - - // For PV nodes only, do a full PV search on the first move or after a fail high, - // otherwise let the parent node fail low with value <= alpha and try another move. - if (PvNode && (moveCount == 1 || value > alpha)) - { - (ss+1)->pv = pv; - (ss+1)->pv[0] = MOVE_NONE; - - value = -search(pos, ss+1, -beta, -alpha, newDepth, false); - } - - // Step 19. Undo move - pos.undo_move(move); - - assert(value > -VALUE_INFINITE && value < VALUE_INFINITE); - - // Step 20. Check for a new best move - // Finished searching the move. If a stop occurred, the return value of - // the search cannot be trusted, and we return immediately without - // updating best move, PV and TT. - if (Threads.stop.load(std::memory_order_relaxed)) - return VALUE_ZERO; - - if (rootNode) - { - RootMove& rm = *std::find(thisThread->rootMoves.begin(), - thisThread->rootMoves.end(), move); - - rm.averageScore = rm.averageScore != -VALUE_INFINITE ? (2 * value + rm.averageScore) / 3 : value; - - // PV move or new best move? - if (moveCount == 1 || value > alpha) - { - rm.score = rm.uciScore = value; - rm.selDepth = thisThread->selDepth; - rm.scoreLowerbound = rm.scoreUpperbound = false; - - if (value >= beta) - { - rm.scoreLowerbound = true; - rm.uciScore = beta; - } - else if (value <= alpha) - { - rm.scoreUpperbound = true; - rm.uciScore = alpha; - } - - rm.pv.resize(1); - - assert((ss+1)->pv); - - for (Move* m = (ss+1)->pv; *m != MOVE_NONE; ++m) - rm.pv.push_back(*m); - - // We record how often the best move has been changed in each iteration. - // This information is used for time management. In MultiPV mode, - // we must take care to only do this for the first PV line. - if ( moveCount > 1 - && !thisThread->pvIdx) - ++thisThread->bestMoveChanges; - } - else - // All other moves but the PV, are set to the lowest value: this - // is not a problem when sorting because the sort is stable and the - // move position in the list is preserved - just the PV is pushed up. - rm.score = -VALUE_INFINITE; - } - - if (value > bestValue) - { - bestValue = value; - - if (value > alpha) - { - bestMove = move; - - if (PvNode && !rootNode) // Update pv even in fail-high case - update_pv(ss->pv, move, (ss+1)->pv); - - if (value >= beta) - { - ss->cutoffCnt += 1 + !ttMove; - assert(value >= beta); // Fail high - break; - } - else - { - // Reduce other moves if we have found at least one score improvement (~2 Elo) - if ( depth > 2 - && depth < 12 - && beta < 14362 - && value > -12393) - depth -= 2; - - assert(depth > 0); - alpha = value; // Update alpha! Always alpha < beta - } - } - } - - - // If the move is worse than some previously searched move, remember it, to update its stats later - if (move != bestMove) - { - if (capture && captureCount < 32) - capturesSearched[captureCount++] = move; - - else if (!capture && quietCount < 64) - quietsSearched[quietCount++] = move; - } - } - // The following condition would detect a stop only after move loop has been - // completed. But in this case, bestValue is valid because we have fully - // searched our subtree, and we can anyhow save the result in TT. - /* + // If the move is worse than some previously searched move, remember it, to update its stats later + if (move != bestMove) { + if (capture && captureCount < 32) + capturesSearched[captureCount++] = move; + + else if (!capture && quietCount < 64) + quietsSearched[quietCount++] = move; + } + } + + // The following condition would detect a stop only after move loop has been + // completed. But in this case, bestValue is valid because we have fully + // searched our subtree, and we can anyhow save the result in TT. + /* if (Threads.stop) return VALUE_DRAW; */ - // Step 21. Check for mate and stalemate - // All legal moves have been searched and if there are no legal moves, it - // must be a mate or a stalemate. If we are in a singular extension search then - // return a fail low score. - - assert(moveCount || !ss->inCheck || excludedMove || !MoveList(pos).size()); - - if (!moveCount) - bestValue = excludedMove ? alpha : - ss->inCheck ? mated_in(ss->ply) - : VALUE_DRAW; - - // If there is a move that produces search value greater than alpha we update the stats of searched moves - else if (bestMove) - update_all_stats(pos, ss, bestMove, bestValue, beta, prevSq, - quietsSearched, quietCount, capturesSearched, captureCount, depth); - - // Bonus for prior countermove that caused the fail low - else if (!priorCapture && prevSq != SQ_NONE) - { - int bonus = (depth > 5) + (PvNode || cutNode) + (bestValue < alpha - 800) + ((ss-1)->moveCount > 12); - update_continuation_histories(ss-1, pos.piece_on(prevSq), prevSq, stat_bonus(depth) * bonus); - thisThread->mainHistory[~us][from_to((ss-1)->currentMove)] << stat_bonus(depth) * bonus / 2; - } + // Step 21. Check for mate and stalemate + // All legal moves have been searched and if there are no legal moves, it + // must be a mate or a stalemate. If we are in a singular extension search then + // return a fail low score. + + assert(moveCount || !ss->inCheck || excludedMove || !MoveList(pos).size()); + + if (!moveCount) + bestValue = excludedMove ? alpha : ss->inCheck ? mated_in(ss->ply) : VALUE_DRAW; + + // If there is a move that produces search value greater than alpha we update the stats of searched moves + else if (bestMove) + update_all_stats(pos, ss, bestMove, bestValue, beta, prevSq, quietsSearched, + quietCount, capturesSearched, captureCount, depth); + + // Bonus for prior countermove that caused the fail low + else if (!priorCapture && prevSq != SQ_NONE) { + int bonus = (depth > 5) + (PvNode || cutNode) + (bestValue < alpha - 800) + + ((ss - 1)->moveCount > 12); + update_continuation_histories(ss - 1, pos.piece_on(prevSq), prevSq, + stat_bonus(depth) * bonus); + thisThread->mainHistory[~us][from_to((ss - 1)->currentMove)] + << stat_bonus(depth) * bonus / 2; + } - if (PvNode) - bestValue = std::min(bestValue, maxValue); - - // If no good move is found and the previous position was ttPv, then the previous - // opponent move is probably good and the new position is added to the search tree. (~7 Elo) - if (bestValue <= alpha) - ss->ttPv = ss->ttPv || ((ss-1)->ttPv && depth > 3); - - // Write gathered information in transposition table - if (!excludedMove && !(rootNode && thisThread->pvIdx)) - tte->save(posKey, value_to_tt(bestValue, ss->ply), ss->ttPv, - bestValue >= beta ? BOUND_LOWER : - PvNode && bestMove ? BOUND_EXACT : BOUND_UPPER, - depth, bestMove, ss->staticEval); - - assert(bestValue > -VALUE_INFINITE && bestValue < VALUE_INFINITE); - - return bestValue; - } - - - // qsearch() is the quiescence search function, which is called by the main search - // function with zero depth, or recursively with further decreasing depth per call. - // (~155 Elo) - template - Value qsearch(Position& pos, Stack* ss, Value alpha, Value beta, Depth depth) { - - static_assert(nodeType != Root); - constexpr bool PvNode = nodeType == PV; - - assert(alpha >= -VALUE_INFINITE && alpha < beta && beta <= VALUE_INFINITE); - assert(PvNode || (alpha == beta - 1)); - assert(depth <= 0); - - // Check if we have an upcoming move that draws by repetition, or - // if the opponent had an alternative move earlier to this position. - if ( depth < 0 - && alpha < VALUE_DRAW - && pos.has_game_cycle(ss->ply)) - { - alpha = value_draw(pos.this_thread()); - if (alpha >= beta) - return alpha; - } + if (PvNode) bestValue = std::min(bestValue, maxValue); - Move pv[MAX_PLY+1]; - StateInfo st; - ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize); - - TTEntry* tte; - Key posKey; - Move ttMove, move, bestMove; - Depth ttDepth; - Value bestValue, value, ttValue, futilityValue, futilityBase; - bool pvHit, givesCheck, capture; - int moveCount; - - // Step 1. Initialize node - if (PvNode) - { - (ss+1)->pv = pv; - ss->pv[0] = MOVE_NONE; - } + // If no good move is found and the previous position was ttPv, then the previous + // opponent move is probably good and the new position is added to the search tree. (~7 Elo) + if (bestValue <= alpha) ss->ttPv = ss->ttPv || ((ss - 1)->ttPv && depth > 3); - Thread* thisThread = pos.this_thread(); - bestMove = MOVE_NONE; - ss->inCheck = pos.checkers(); - moveCount = 0; - - // Step 2. Check for an immediate draw or maximum ply reached - if ( pos.is_draw(ss->ply) - || ss->ply >= MAX_PLY) - return (ss->ply >= MAX_PLY && !ss->inCheck) ? evaluate(pos) : VALUE_DRAW; - - assert(0 <= ss->ply && ss->ply < MAX_PLY); - - // Decide whether or not to include checks: this fixes also the type of - // TT entry depth that we are going to use. Note that in qsearch we use - // only two types of depth in TT: DEPTH_QS_CHECKS or DEPTH_QS_NO_CHECKS. - ttDepth = ss->inCheck || depth >= DEPTH_QS_CHECKS ? DEPTH_QS_CHECKS - : DEPTH_QS_NO_CHECKS; - - // Step 3. Transposition table lookup - posKey = pos.key(); - tte = TT.probe(posKey, ss->ttHit); - ttValue = ss->ttHit ? value_from_tt(tte->value(), ss->ply, pos.rule50_count()) : VALUE_NONE; - ttMove = ss->ttHit ? tte->move() : MOVE_NONE; - pvHit = ss->ttHit && tte->is_pv(); - - // At non-PV nodes we check for an early TT cutoff - if ( !PvNode - && tte->depth() >= ttDepth - && ttValue != VALUE_NONE // Only in case of TT access race or if !ttHit - && (tte->bound() & (ttValue >= beta ? BOUND_LOWER : BOUND_UPPER))) - return ttValue; - - // Step 4. Static evaluation of the position - if (ss->inCheck) - bestValue = futilityBase = -VALUE_INFINITE; - else - { - if (ss->ttHit) - { - // Never assume anything about values stored in TT - if ((ss->staticEval = bestValue = tte->eval()) == VALUE_NONE) - ss->staticEval = bestValue = evaluate(pos); - - // ttValue can be used as a better position evaluation (~13 Elo) - if ( ttValue != VALUE_NONE - && (tte->bound() & (ttValue > bestValue ? BOUND_LOWER : BOUND_UPPER))) - bestValue = ttValue; - } - else - // In case of null move search use previous static eval with a different sign - ss->staticEval = bestValue = (ss-1)->currentMove != MOVE_NULL ? evaluate(pos) - : -(ss-1)->staticEval; - - // Stand pat. Return immediately if static value is at least beta - if (bestValue >= beta) - { - // Save gathered info in transposition table - if (!ss->ttHit) - tte->save(posKey, value_to_tt(bestValue, ss->ply), false, BOUND_LOWER, - DEPTH_NONE, MOVE_NONE, ss->staticEval); + // Write gathered information in transposition table + if (!excludedMove && !(rootNode && thisThread->pvIdx)) + tte->save(posKey, value_to_tt(bestValue, ss->ply), ss->ttPv, + bestValue >= beta ? BOUND_LOWER : + PvNode && bestMove ? BOUND_EXACT : + BOUND_UPPER, + depth, bestMove, ss->staticEval); + + assert(bestValue > -VALUE_INFINITE && bestValue < VALUE_INFINITE); return bestValue; } - if (bestValue > alpha) - alpha = bestValue; - futilityBase = std::min(ss->staticEval, bestValue) + 200; - } + // qsearch() is the quiescence search function, which is called by the main search + // function with zero depth, or recursively with further decreasing depth per call. + // (~155 Elo) + template + Value qsearch(Position& pos, Stack* ss, Value alpha, Value beta, Depth depth) { - const PieceToHistory* contHist[] = { (ss-1)->continuationHistory, (ss-2)->continuationHistory, - nullptr , (ss-4)->continuationHistory, - nullptr , (ss-6)->continuationHistory }; - - // Initialize a MovePicker object for the current position, and prepare - // to search the moves. Because the depth is <= 0 here, only captures, - // queen promotions, and other checks (only if depth >= DEPTH_QS_CHECKS) - // will be generated. - Square prevSq = is_ok((ss-1)->currentMove) ? to_sq((ss-1)->currentMove) : SQ_NONE; - MovePicker mp(pos, ttMove, depth, &thisThread->mainHistory, - &thisThread->captureHistory, - contHist, - prevSq); - - int quietCheckEvasions = 0; - - // Step 5. Loop through all pseudo-legal moves until no moves remain - // or a beta cutoff occurs. - while ((move = mp.next_move()) != MOVE_NONE) - { - assert(is_ok(move)); - - // Check for legality - if (!pos.legal(move)) - continue; - - givesCheck = pos.gives_check(move); - capture = pos.capture_stage(move); - - moveCount++; - - // Step 6. Pruning. - if (bestValue > VALUE_TB_LOSS_IN_MAX_PLY) - { - // Futility pruning and moveCount pruning (~10 Elo) - if ( !givesCheck - && to_sq(move) != prevSq - && futilityBase > -VALUE_KNOWN_WIN - && type_of(move) != PROMOTION) - { - if (moveCount > 2) - continue; + static_assert(nodeType != Root); + constexpr bool PvNode = nodeType == PV; - futilityValue = futilityBase + PieceValue[pos.piece_on(to_sq(move))]; + assert(alpha >= -VALUE_INFINITE && alpha < beta && beta <= VALUE_INFINITE); + assert(PvNode || (alpha == beta - 1)); + assert(depth <= 0); - // If static eval + value of piece we are going to capture is much lower - // than alpha we can prune this move - if (futilityValue <= alpha) - { - bestValue = std::max(bestValue, futilityValue); - continue; - } + // Check if we have an upcoming move that draws by repetition, or + // if the opponent had an alternative move earlier to this position. + if (depth < 0 && alpha < VALUE_DRAW && pos.has_game_cycle(ss->ply)) { + alpha = value_draw(pos.this_thread()); + if (alpha >= beta) return alpha; + } - // If static eval is much lower than alpha and move is not winning material - // we can prune this move - if (futilityBase <= alpha && !pos.see_ge(move, VALUE_ZERO + 1)) - { - bestValue = std::max(bestValue, futilityBase); - continue; - } + Move pv[MAX_PLY + 1]; + StateInfo st; + ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize); + + TTEntry* tte; + Key posKey; + Move ttMove, move, bestMove; + Depth ttDepth; + Value bestValue, value, ttValue, futilityValue, futilityBase; + bool pvHit, givesCheck, capture; + int moveCount; + + // Step 1. Initialize node + if (PvNode) { + (ss + 1)->pv = pv; + ss->pv[0] = MOVE_NONE; + } - // If static exchange evaluation is much worse than what is needed to not - // fall below alpha we can prune this move - if (futilityBase > alpha && !pos.see_ge(move, (alpha - futilityBase) * 4)) - { - bestValue = alpha; - continue; + Thread* thisThread = pos.this_thread(); + bestMove = MOVE_NONE; + ss->inCheck = pos.checkers(); + moveCount = 0; + + // Step 2. Check for an immediate draw or maximum ply reached + if (pos.is_draw(ss->ply) || ss->ply >= MAX_PLY) + return (ss->ply >= MAX_PLY && !ss->inCheck) ? evaluate(pos) : VALUE_DRAW; + + assert(0 <= ss->ply && ss->ply < MAX_PLY); + + // Decide whether or not to include checks: this fixes also the type of + // TT entry depth that we are going to use. Note that in qsearch we use + // only two types of depth in TT: DEPTH_QS_CHECKS or DEPTH_QS_NO_CHECKS. + ttDepth = + ss->inCheck || depth >= DEPTH_QS_CHECKS ? DEPTH_QS_CHECKS : DEPTH_QS_NO_CHECKS; + + // Step 3. Transposition table lookup + posKey = pos.key(); + tte = TT.probe(posKey, ss->ttHit); + ttValue = + ss->ttHit ? value_from_tt(tte->value(), ss->ply, pos.rule50_count()) : VALUE_NONE; + ttMove = ss->ttHit ? tte->move() : MOVE_NONE; + pvHit = ss->ttHit && tte->is_pv(); + + // At non-PV nodes we check for an early TT cutoff + if (!PvNode && tte->depth() >= ttDepth && + ttValue != VALUE_NONE // Only in case of TT access race or if !ttHit + && (tte->bound() & (ttValue >= beta ? BOUND_LOWER : BOUND_UPPER))) + return ttValue; + + // Step 4. Static evaluation of the position + if (ss->inCheck) + bestValue = futilityBase = -VALUE_INFINITE; + else { + if (ss->ttHit) { + // Never assume anything about values stored in TT + if ((ss->staticEval = bestValue = tte->eval()) == VALUE_NONE) + ss->staticEval = bestValue = evaluate(pos); + + // ttValue can be used as a better position evaluation (~13 Elo) + if (ttValue != VALUE_NONE && + (tte->bound() & (ttValue > bestValue ? BOUND_LOWER : BOUND_UPPER))) + bestValue = ttValue; + } else + // In case of null move search use previous static eval with a different sign + ss->staticEval = bestValue = + (ss - 1)->currentMove != MOVE_NULL ? evaluate(pos) : -(ss - 1)->staticEval; + + // Stand pat. Return immediately if static value is at least beta + if (bestValue >= beta) { + // Save gathered info in transposition table + if (!ss->ttHit) + tte->save(posKey, value_to_tt(bestValue, ss->ply), false, BOUND_LOWER, + DEPTH_NONE, MOVE_NONE, ss->staticEval); + + return bestValue; } + + if (bestValue > alpha) alpha = bestValue; + + futilityBase = std::min(ss->staticEval, bestValue) + 200; } - // We prune after the second quiet check evasion move, where being 'in check' is - // implicitly checked through the counter, and being a 'quiet move' apart from - // being a tt move is assumed after an increment because captures are pushed ahead. - if (quietCheckEvasions > 1) - break; - - // Continuation history based pruning (~3 Elo) - if ( !capture - && (*contHist[0])[pos.moved_piece(move)][to_sq(move)] < 0 - && (*contHist[1])[pos.moved_piece(move)][to_sq(move)] < 0) - continue; - - // Do not search moves with bad enough SEE values (~5 Elo) - if (!pos.see_ge(move, Value(-95))) - continue; - } + const PieceToHistory* contHist[] = {(ss - 1)->continuationHistory, + (ss - 2)->continuationHistory, + nullptr, + (ss - 4)->continuationHistory, + nullptr, + (ss - 6)->continuationHistory}; + + // Initialize a MovePicker object for the current position, and prepare + // to search the moves. Because the depth is <= 0 here, only captures, + // queen promotions, and other checks (only if depth >= DEPTH_QS_CHECKS) + // will be generated. + Square prevSq = is_ok((ss - 1)->currentMove) ? to_sq((ss - 1)->currentMove) : SQ_NONE; + MovePicker mp(pos, ttMove, depth, &thisThread->mainHistory, &thisThread->captureHistory, + contHist, prevSq); + + int quietCheckEvasions = 0; + + // Step 5. Loop through all pseudo-legal moves until no moves remain + // or a beta cutoff occurs. + while ((move = mp.next_move()) != MOVE_NONE) { + assert(is_ok(move)); + + // Check for legality + if (!pos.legal(move)) continue; + + givesCheck = pos.gives_check(move); + capture = pos.capture_stage(move); + + moveCount++; + + // Step 6. Pruning. + if (bestValue > VALUE_TB_LOSS_IN_MAX_PLY) { + // Futility pruning and moveCount pruning (~10 Elo) + if (!givesCheck && to_sq(move) != prevSq && futilityBase > -VALUE_KNOWN_WIN && + type_of(move) != PROMOTION) { + if (moveCount > 2) continue; + + futilityValue = futilityBase + PieceValue[pos.piece_on(to_sq(move))]; + + // If static eval + value of piece we are going to capture is much lower + // than alpha we can prune this move + if (futilityValue <= alpha) { + bestValue = std::max(bestValue, futilityValue); + continue; + } + + // If static eval is much lower than alpha and move is not winning material + // we can prune this move + if (futilityBase <= alpha && !pos.see_ge(move, VALUE_ZERO + 1)) { + bestValue = std::max(bestValue, futilityBase); + continue; + } + + // If static exchange evaluation is much worse than what is needed to not + // fall below alpha we can prune this move + if (futilityBase > alpha && !pos.see_ge(move, (alpha - futilityBase) * 4)) { + bestValue = alpha; + continue; + } + } + + // We prune after the second quiet check evasion move, where being 'in check' is + // implicitly checked through the counter, and being a 'quiet move' apart from + // being a tt move is assumed after an increment because captures are pushed ahead. + if (quietCheckEvasions > 1) break; + + // Continuation history based pruning (~3 Elo) + if (!capture && (*contHist[0])[pos.moved_piece(move)][to_sq(move)] < 0 && + (*contHist[1])[pos.moved_piece(move)][to_sq(move)] < 0) + continue; + + // Do not search moves with bad enough SEE values (~5 Elo) + if (!pos.see_ge(move, Value(-95))) continue; + } - // Speculative prefetch as early as possible - prefetch(TT.first_entry(pos.key_after(move))); + // Speculative prefetch as early as possible + prefetch(TT.first_entry(pos.key_after(move))); - // Update the current move - ss->currentMove = move; - ss->continuationHistory = &thisThread->continuationHistory[ss->inCheck] - [capture] - [pos.moved_piece(move)] - [to_sq(move)]; + // Update the current move + ss->currentMove = move; + ss->continuationHistory = + &thisThread->continuationHistory[ss->inCheck][capture][pos.moved_piece(move)] + [to_sq(move)]; - quietCheckEvasions += !capture && ss->inCheck; + quietCheckEvasions += !capture && ss->inCheck; - // Step 7. Make and search the move - pos.do_move(move, st, givesCheck); - value = -qsearch(pos, ss+1, -beta, -alpha, depth - 1); - pos.undo_move(move); + // Step 7. Make and search the move + pos.do_move(move, st, givesCheck); + value = -qsearch(pos, ss + 1, -beta, -alpha, depth - 1); + pos.undo_move(move); - assert(value > -VALUE_INFINITE && value < VALUE_INFINITE); + assert(value > -VALUE_INFINITE && value < VALUE_INFINITE); - // Step 8. Check for a new best move - if (value > bestValue) - { - bestValue = value; + // Step 8. Check for a new best move + if (value > bestValue) { + bestValue = value; - if (value > alpha) - { - bestMove = move; + if (value > alpha) { + bestMove = move; - if (PvNode) // Update pv even in fail-high case - update_pv(ss->pv, move, (ss+1)->pv); + if (PvNode) // Update pv even in fail-high case + update_pv(ss->pv, move, (ss + 1)->pv); - if (value < beta) // Update alpha here! - alpha = value; - else - break; // Fail high + if (value < beta) // Update alpha here! + alpha = value; + else + break; // Fail high + } + } } - } - } - // Step 9. Check for mate - // All legal moves have been searched. A special case: if we're in check - // and no legal moves were found, it is checkmate. - if (ss->inCheck && bestValue == -VALUE_INFINITE) - { - assert(!MoveList(pos).size()); + // Step 9. Check for mate + // All legal moves have been searched. A special case: if we're in check + // and no legal moves were found, it is checkmate. + if (ss->inCheck && bestValue == -VALUE_INFINITE) { + assert(!MoveList(pos).size()); - return mated_in(ss->ply); // Plies to mate from the root - } + return mated_in(ss->ply); // Plies to mate from the root + } - // Save gathered info in transposition table - tte->save(posKey, value_to_tt(bestValue, ss->ply), pvHit, - bestValue >= beta ? BOUND_LOWER : BOUND_UPPER, - ttDepth, bestMove, ss->staticEval); + // Save gathered info in transposition table + tte->save(posKey, value_to_tt(bestValue, ss->ply), pvHit, + bestValue >= beta ? BOUND_LOWER : BOUND_UPPER, ttDepth, bestMove, + ss->staticEval); - assert(bestValue > -VALUE_INFINITE && bestValue < VALUE_INFINITE); + assert(bestValue > -VALUE_INFINITE && bestValue < VALUE_INFINITE); - return bestValue; - } + return bestValue; + } - // value_to_tt() adjusts a mate or TB score from "plies to mate from the root" to - // "plies to mate from the current position". Standard scores are unchanged. - // The function is called before storing a value in the transposition table. + // value_to_tt() adjusts a mate or TB score from "plies to mate from the root" to + // "plies to mate from the current position". Standard scores are unchanged. + // The function is called before storing a value in the transposition table. - Value value_to_tt(Value v, int ply) { + Value value_to_tt(Value v, int ply) { - assert(v != VALUE_NONE); + assert(v != VALUE_NONE); - return v >= VALUE_TB_WIN_IN_MAX_PLY ? v + ply - : v <= VALUE_TB_LOSS_IN_MAX_PLY ? v - ply : v; - } + return v >= VALUE_TB_WIN_IN_MAX_PLY ? v + ply : + v <= VALUE_TB_LOSS_IN_MAX_PLY ? v - ply : + v; + } - // value_from_tt() is the inverse of value_to_tt(): it adjusts a mate or TB score - // from the transposition table (which refers to the plies to mate/be mated from - // current position) to "plies to mate/be mated (TB win/loss) from the root". However, - // for mate scores, to avoid potentially false mate scores related to the 50 moves rule - // and the graph history interaction, we return an optimal TB score instead. + // value_from_tt() is the inverse of value_to_tt(): it adjusts a mate or TB score + // from the transposition table (which refers to the plies to mate/be mated from + // current position) to "plies to mate/be mated (TB win/loss) from the root". However, + // for mate scores, to avoid potentially false mate scores related to the 50 moves rule + // and the graph history interaction, we return an optimal TB score instead. - Value value_from_tt(Value v, int ply, int r50c) { + Value value_from_tt(Value v, int ply, int r50c) { - if (v == VALUE_NONE) - return VALUE_NONE; + if (v == VALUE_NONE) return VALUE_NONE; - if (v >= VALUE_TB_WIN_IN_MAX_PLY) // TB win or better - { - if (v >= VALUE_MATE_IN_MAX_PLY && VALUE_MATE - v > 99 - r50c) - return VALUE_MATE_IN_MAX_PLY - 1; // do not return a potentially false mate score + if (v >= VALUE_TB_WIN_IN_MAX_PLY) // TB win or better + { + if (v >= VALUE_MATE_IN_MAX_PLY && VALUE_MATE - v > 99 - r50c) + return VALUE_MATE_IN_MAX_PLY - + 1; // do not return a potentially false mate score - return v - ply; - } + return v - ply; + } - if (v <= VALUE_TB_LOSS_IN_MAX_PLY) // TB loss or worse - { - if (v <= VALUE_MATED_IN_MAX_PLY && VALUE_MATE + v > 99 - r50c) - return VALUE_MATED_IN_MAX_PLY + 1; // do not return a potentially false mate score + if (v <= VALUE_TB_LOSS_IN_MAX_PLY) // TB loss or worse + { + if (v <= VALUE_MATED_IN_MAX_PLY && VALUE_MATE + v > 99 - r50c) + return VALUE_MATED_IN_MAX_PLY + + 1; // do not return a potentially false mate score - return v + ply; - } + return v + ply; + } - return v; - } + return v; + } - // update_pv() adds current move and appends child pv[] + // update_pv() adds current move and appends child pv[] - void update_pv(Move* pv, Move move, const Move* childPv) { + void update_pv(Move* pv, Move move, const Move* childPv) { - for (*pv++ = move; childPv && *childPv != MOVE_NONE; ) - *pv++ = *childPv++; - *pv = MOVE_NONE; - } + for (*pv++ = move; childPv && *childPv != MOVE_NONE;) *pv++ = *childPv++; + *pv = MOVE_NONE; + } - // update_all_stats() updates stats at the end of search() when a bestMove is found + // update_all_stats() updates stats at the end of search() when a bestMove is found - void update_all_stats(const Position& pos, Stack* ss, Move bestMove, Value bestValue, Value beta, Square prevSq, - Move* quietsSearched, int quietCount, Move* capturesSearched, int captureCount, Depth depth) { + void update_all_stats(const Position& pos, Stack* ss, Move bestMove, Value bestValue, + Value beta, Square prevSq, Move* quietsSearched, int quietCount, + Move* capturesSearched, int captureCount, Depth depth) { - Color us = pos.side_to_move(); - Thread* thisThread = pos.this_thread(); - CapturePieceToHistory& captureHistory = thisThread->captureHistory; - Piece moved_piece = pos.moved_piece(bestMove); - PieceType captured; + Color us = pos.side_to_move(); + Thread* thisThread = pos.this_thread(); + CapturePieceToHistory& captureHistory = thisThread->captureHistory; + Piece moved_piece = pos.moved_piece(bestMove); + PieceType captured; - int quietMoveBonus = stat_bonus(depth + 1); + int quietMoveBonus = stat_bonus(depth + 1); - if (!pos.capture_stage(bestMove)) - { - int bestMoveBonus = bestValue > beta + 145 ? quietMoveBonus // larger bonus - : stat_bonus(depth); // smaller bonus + if (!pos.capture_stage(bestMove)) { + int bestMoveBonus = bestValue > beta + 145 ? quietMoveBonus // larger bonus + : + stat_bonus(depth); // smaller bonus - // Increase stats for the best move in case it was a quiet move - update_quiet_stats(pos, ss, bestMove, bestMoveBonus); + // Increase stats for the best move in case it was a quiet move + update_quiet_stats(pos, ss, bestMove, bestMoveBonus); - // Decrease stats for all non-best quiet moves - for (int i = 0; i < quietCount; ++i) - { - thisThread->mainHistory[us][from_to(quietsSearched[i])] << -bestMoveBonus; - update_continuation_histories(ss, pos.moved_piece(quietsSearched[i]), to_sq(quietsSearched[i]), -bestMoveBonus); - } - } - else - { - // Increase stats for the best move in case it was a capture move - captured = type_of(pos.piece_on(to_sq(bestMove))); - captureHistory[moved_piece][to_sq(bestMove)][captured] << quietMoveBonus; - } + // Decrease stats for all non-best quiet moves + for (int i = 0; i < quietCount; ++i) { + thisThread->mainHistory[us][from_to(quietsSearched[i])] << -bestMoveBonus; + update_continuation_histories(ss, pos.moved_piece(quietsSearched[i]), + to_sq(quietsSearched[i]), -bestMoveBonus); + } + } else { + // Increase stats for the best move in case it was a capture move + captured = type_of(pos.piece_on(to_sq(bestMove))); + captureHistory[moved_piece][to_sq(bestMove)][captured] << quietMoveBonus; + } - // Extra penalty for a quiet early move that was not a TT move or - // main killer move in previous ply when it gets refuted. - if ( prevSq != SQ_NONE - && ((ss-1)->moveCount == 1 + (ss-1)->ttHit || ((ss-1)->currentMove == (ss-1)->killers[0])) - && !pos.captured_piece()) - update_continuation_histories(ss-1, pos.piece_on(prevSq), prevSq, -quietMoveBonus); - - // Decrease stats for all non-best capture moves - for (int i = 0; i < captureCount; ++i) - { - moved_piece = pos.moved_piece(capturesSearched[i]); - captured = type_of(pos.piece_on(to_sq(capturesSearched[i]))); - captureHistory[moved_piece][to_sq(capturesSearched[i])][captured] << -quietMoveBonus; - } - } + // Extra penalty for a quiet early move that was not a TT move or + // main killer move in previous ply when it gets refuted. + if (prevSq != SQ_NONE && + ((ss - 1)->moveCount == 1 + (ss - 1)->ttHit || + ((ss - 1)->currentMove == (ss - 1)->killers[0])) && + !pos.captured_piece()) + update_continuation_histories(ss - 1, pos.piece_on(prevSq), prevSq, + -quietMoveBonus); + + // Decrease stats for all non-best capture moves + for (int i = 0; i < captureCount; ++i) { + moved_piece = pos.moved_piece(capturesSearched[i]); + captured = type_of(pos.piece_on(to_sq(capturesSearched[i]))); + captureHistory[moved_piece][to_sq(capturesSearched[i])][captured] + << -quietMoveBonus; + } + } - // update_continuation_histories() updates histories of the move pairs formed - // by moves at ply -1, -2, -4, and -6 with current move. + // update_continuation_histories() updates histories of the move pairs formed + // by moves at ply -1, -2, -4, and -6 with current move. - void update_continuation_histories(Stack* ss, Piece pc, Square to, int bonus) { + void update_continuation_histories(Stack* ss, Piece pc, Square to, int bonus) { - for (int i : {1, 2, 4, 6}) - { - // Only update the first 2 continuation histories if we are in check - if (ss->inCheck && i > 2) - break; - if (is_ok((ss-i)->currentMove)) - (*(ss-i)->continuationHistory)[pc][to] << bonus; - } - } + for (int i : {1, 2, 4, 6}) { + // Only update the first 2 continuation histories if we are in check + if (ss->inCheck && i > 2) break; + if (is_ok((ss - i)->currentMove)) (*(ss - i)->continuationHistory)[pc][to] << bonus; + } + } - // update_quiet_stats() updates move sorting heuristics + // update_quiet_stats() updates move sorting heuristics - void update_quiet_stats(const Position& pos, Stack* ss, Move move, int bonus) { + void update_quiet_stats(const Position& pos, Stack* ss, Move move, int bonus) { - // Update killers - if (ss->killers[0] != move) - { - ss->killers[1] = ss->killers[0]; - ss->killers[0] = move; - } + // Update killers + if (ss->killers[0] != move) { + ss->killers[1] = ss->killers[0]; + ss->killers[0] = move; + } - Color us = pos.side_to_move(); - Thread* thisThread = pos.this_thread(); - thisThread->mainHistory[us][from_to(move)] << bonus; - update_continuation_histories(ss, pos.moved_piece(move), to_sq(move), bonus); + Color us = pos.side_to_move(); + Thread* thisThread = pos.this_thread(); + thisThread->mainHistory[us][from_to(move)] << bonus; + update_continuation_histories(ss, pos.moved_piece(move), to_sq(move), bonus); - // Update countermove history - if (is_ok((ss-1)->currentMove)) - { - Square prevSq = to_sq((ss-1)->currentMove); - thisThread->counterMoves[pos.piece_on(prevSq)][prevSq] = move; - } - } - - // When playing with strength handicap, choose the best move among a set of RootMoves - // using a statistical rule dependent on 'level'. Idea by Heinz van Saanen. - - Move Skill::pick_best(size_t multiPV) { - - const RootMoves& rootMoves = Threads.main()->rootMoves; - static PRNG rng(now()); // PRNG sequence should be non-deterministic - - // RootMoves are already sorted by score in descending order - Value topScore = rootMoves[0].score; - int delta = std::min(topScore - rootMoves[multiPV - 1].score, PawnValue); - int maxScore = -VALUE_INFINITE; - double weakness = 120 - 2 * level; - - // Choose best move. For each move score we add two terms, both dependent on - // weakness. One is deterministic and bigger for weaker levels, and one is - // random. Then we choose the move with the resulting highest score. - for (size_t i = 0; i < multiPV; ++i) - { - // This is our magic formula - int push = int(( weakness * int(topScore - rootMoves[i].score) - + delta * (rng.rand() % int(weakness))) / 128); - - if (rootMoves[i].score + push >= maxScore) - { - maxScore = rootMoves[i].score + push; - best = rootMoves[i].pv[0]; + // Update countermove history + if (is_ok((ss - 1)->currentMove)) { + Square prevSq = to_sq((ss - 1)->currentMove); + thisThread->counterMoves[pos.piece_on(prevSq)][prevSq] = move; + } } - } - return best; - } + // When playing with strength handicap, choose the best move among a set of RootMoves + // using a statistical rule dependent on 'level'. Idea by Heinz van Saanen. -} // namespace + Move Skill::pick_best(size_t multiPV) { + const RootMoves& rootMoves = Threads.main()->rootMoves; + static PRNG rng(now()); // PRNG sequence should be non-deterministic -/// MainThread::check_time() is used to print debug info and, more importantly, -/// to detect when we are out of available time and thus stop the search. + // RootMoves are already sorted by score in descending order + Value topScore = rootMoves[0].score; + int delta = std::min(topScore - rootMoves[multiPV - 1].score, PawnValue); + int maxScore = -VALUE_INFINITE; + double weakness = 120 - 2 * level; -void MainThread::check_time() { + // Choose best move. For each move score we add two terms, both dependent on + // weakness. One is deterministic and bigger for weaker levels, and one is + // random. Then we choose the move with the resulting highest score. + for (size_t i = 0; i < multiPV; ++i) { + // This is our magic formula + int push = int((weakness * int(topScore - rootMoves[i].score) + + delta * (rng.rand() % int(weakness))) / + 128); - if (--callsCnt > 0) - return; + if (rootMoves[i].score + push >= maxScore) { + maxScore = rootMoves[i].score + push; + best = rootMoves[i].pv[0]; + } + } - // When using nodes, ensure checking rate is not lower than 0.1% of nodes - callsCnt = Limits.nodes ? std::min(512, int(Limits.nodes / 1024)) : 512; + return best; + } - static TimePoint lastInfoTime = now(); + } // namespace - TimePoint elapsed = Time.elapsed(); - TimePoint tick = Limits.startTime + elapsed; - if (tick - lastInfoTime >= 1000) - { - lastInfoTime = tick; - dbg_print(); - } + /// MainThread::check_time() is used to print debug info and, more importantly, + /// to detect when we are out of available time and thus stop the search. - // We should not stop pondering until told so by the GUI - if (ponder) - return; + void MainThread::check_time() { - if ( (Limits.use_time_management() && (elapsed > Time.maximum() || stopOnPonderhit)) - || (Limits.movetime && elapsed >= Limits.movetime) - || (Limits.nodes && Threads.nodes_searched() >= (uint64_t)Limits.nodes)) - Threads.stop = true; -} + if (--callsCnt > 0) return; + // When using nodes, ensure checking rate is not lower than 0.1% of nodes + callsCnt = Limits.nodes ? std::min(512, int(Limits.nodes / 1024)) : 512; -/// UCI::pv() formats PV information according to the UCI protocol. UCI requires -/// that all (if any) unsearched PV lines are sent using a previous search score. + static TimePoint lastInfoTime = now(); -string UCI::pv(const Position& pos, Depth depth) { + TimePoint elapsed = Time.elapsed(); + TimePoint tick = Limits.startTime + elapsed; - std::stringstream ss; - TimePoint elapsed = Time.elapsed() + 1; - const RootMoves& rootMoves = pos.this_thread()->rootMoves; - size_t pvIdx = pos.this_thread()->pvIdx; - size_t multiPV = std::min((size_t)Options["MultiPV"], rootMoves.size()); - uint64_t nodesSearched = Threads.nodes_searched(); - uint64_t tbHits = Threads.tb_hits() + (TB::RootInTB ? rootMoves.size() : 0); + if (tick - lastInfoTime >= 1000) { + lastInfoTime = tick; + dbg_print(); + } - for (size_t i = 0; i < multiPV; ++i) - { - bool updated = rootMoves[i].score != -VALUE_INFINITE; + // We should not stop pondering until told so by the GUI + if (ponder) return; - if (depth == 1 && !updated && i > 0) - continue; + if ((Limits.use_time_management() && (elapsed > Time.maximum() || stopOnPonderhit)) || + (Limits.movetime && elapsed >= Limits.movetime) || + (Limits.nodes && Threads.nodes_searched() >= (uint64_t) Limits.nodes)) + Threads.stop = true; + } - Depth d = updated ? depth : std::max(1, depth - 1); - Value v = updated ? rootMoves[i].uciScore : rootMoves[i].previousScore; - if (v == -VALUE_INFINITE) - v = VALUE_ZERO; + /// UCI::pv() formats PV information according to the UCI protocol. UCI requires + /// that all (if any) unsearched PV lines are sent using a previous search score. - bool tb = TB::RootInTB && abs(v) < VALUE_MATE_IN_MAX_PLY; - v = tb ? rootMoves[i].tbScore : v; + string UCI::pv(const Position& pos, Depth depth) { - if (ss.rdbuf()->in_avail()) // Not at first line - ss << "\n"; + std::stringstream ss; + TimePoint elapsed = Time.elapsed() + 1; + const RootMoves& rootMoves = pos.this_thread()->rootMoves; + size_t pvIdx = pos.this_thread()->pvIdx; + size_t multiPV = std::min((size_t) Options["MultiPV"], rootMoves.size()); + uint64_t nodesSearched = Threads.nodes_searched(); + uint64_t tbHits = Threads.tb_hits() + (TB::RootInTB ? rootMoves.size() : 0); - ss << "info" - << " depth " << d - << " seldepth " << rootMoves[i].selDepth - << " multipv " << i + 1 - << " score " << UCI::value(v); + for (size_t i = 0; i < multiPV; ++i) { + bool updated = rootMoves[i].score != -VALUE_INFINITE; - if (Options["UCI_ShowWDL"]) - ss << UCI::wdl(v, pos.game_ply()); + if (depth == 1 && !updated && i > 0) continue; - if (i == pvIdx && !tb && updated) // tablebase- and previous-scores are exact - ss << (rootMoves[i].scoreLowerbound ? " lowerbound" : (rootMoves[i].scoreUpperbound ? " upperbound" : "")); + Depth d = updated ? depth : std::max(1, depth - 1); + Value v = updated ? rootMoves[i].uciScore : rootMoves[i].previousScore; - ss << " nodes " << nodesSearched - << " nps " << nodesSearched * 1000 / elapsed - << " hashfull " << TT.hashfull() - << " tbhits " << tbHits - << " time " << elapsed - << " pv"; + if (v == -VALUE_INFINITE) v = VALUE_ZERO; - for (Move m : rootMoves[i].pv) - ss << " " << UCI::move(m, pos.is_chess960()); - } + bool tb = TB::RootInTB && abs(v) < VALUE_MATE_IN_MAX_PLY; + v = tb ? rootMoves[i].tbScore : v; - return ss.str(); -} + if (ss.rdbuf()->in_avail()) // Not at first line + ss << "\n"; + ss << "info" + << " depth " << d << " seldepth " << rootMoves[i].selDepth << " multipv " << i + 1 + << " score " << UCI::value(v); -/// RootMove::extract_ponder_from_tt() is called in case we have no ponder move -/// before exiting the search, for instance, in case we stop the search during a -/// fail high at root. We try hard to have a ponder move to return to the GUI, -/// otherwise in case of 'ponder on' we have nothing to think about. + if (Options["UCI_ShowWDL"]) ss << UCI::wdl(v, pos.game_ply()); -bool RootMove::extract_ponder_from_tt(Position& pos) { + if (i == pvIdx && !tb && updated) // tablebase- and previous-scores are exact + ss << (rootMoves[i].scoreLowerbound ? + " lowerbound" : + (rootMoves[i].scoreUpperbound ? " upperbound" : "")); - StateInfo st; - ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize); + ss << " nodes " << nodesSearched << " nps " << nodesSearched * 1000 / elapsed + << " hashfull " << TT.hashfull() << " tbhits " << tbHits << " time " << elapsed + << " pv"; - bool ttHit; + for (Move m : rootMoves[i].pv) ss << " " << UCI::move(m, pos.is_chess960()); + } - assert(pv.size() == 1); + return ss.str(); + } - if (pv[0] == MOVE_NONE) - return false; - pos.do_move(pv[0], st); - TTEntry* tte = TT.probe(pos.key(), ttHit); + /// RootMove::extract_ponder_from_tt() is called in case we have no ponder move + /// before exiting the search, for instance, in case we stop the search during a + /// fail high at root. We try hard to have a ponder move to return to the GUI, + /// otherwise in case of 'ponder on' we have nothing to think about. - if (ttHit) - { - Move m = tte->move(); // Local copy to be SMP safe - if (MoveList(pos).contains(m)) - pv.push_back(m); - } + bool RootMove::extract_ponder_from_tt(Position& pos) { - pos.undo_move(pv[0]); - return pv.size() > 1; -} + StateInfo st; + ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize); -void Tablebases::rank_root_moves(Position& pos, Search::RootMoves& rootMoves) { + bool ttHit; - RootInTB = false; - UseRule50 = bool(Options["Syzygy50MoveRule"]); - ProbeDepth = int(Options["SyzygyProbeDepth"]); - Cardinality = int(Options["SyzygyProbeLimit"]); - bool dtz_available = true; + assert(pv.size() == 1); - // Tables with fewer pieces than SyzygyProbeLimit are searched with - // ProbeDepth == DEPTH_ZERO - if (Cardinality > MaxCardinality) - { - Cardinality = MaxCardinality; - ProbeDepth = 0; - } + if (pv[0] == MOVE_NONE) return false; - if (Cardinality >= popcount(pos.pieces()) && !pos.can_castle(ANY_CASTLING)) - { - // Rank moves using DTZ tables - RootInTB = root_probe(pos, rootMoves); + pos.do_move(pv[0], st); + TTEntry* tte = TT.probe(pos.key(), ttHit); - if (!RootInTB) - { - // DTZ tables are missing; try to rank moves using WDL tables - dtz_available = false; - RootInTB = root_probe_wdl(pos, rootMoves); + if (ttHit) { + Move m = tte->move(); // Local copy to be SMP safe + if (MoveList(pos).contains(m)) pv.push_back(m); } + + pos.undo_move(pv[0]); + return pv.size() > 1; } - if (RootInTB) - { - // Sort moves according to TB rank - std::stable_sort(rootMoves.begin(), rootMoves.end(), - [](const RootMove &a, const RootMove &b) { return a.tbRank > b.tbRank; } ); + void Tablebases::rank_root_moves(Position& pos, Search::RootMoves& rootMoves) { - // Probe during search only if DTZ is not available and we are winning - if (dtz_available || rootMoves[0].tbScore <= VALUE_DRAW) - Cardinality = 0; - } - else - { - // Clean up if root_probe() and root_probe_wdl() have failed - for (auto& m : rootMoves) - m.tbRank = 0; + RootInTB = false; + UseRule50 = bool(Options["Syzygy50MoveRule"]); + ProbeDepth = int(Options["SyzygyProbeDepth"]); + Cardinality = int(Options["SyzygyProbeLimit"]); + bool dtz_available = true; + + // Tables with fewer pieces than SyzygyProbeLimit are searched with + // ProbeDepth == DEPTH_ZERO + if (Cardinality > MaxCardinality) { + Cardinality = MaxCardinality; + ProbeDepth = 0; + } + + if (Cardinality >= popcount(pos.pieces()) && !pos.can_castle(ANY_CASTLING)) { + // Rank moves using DTZ tables + RootInTB = root_probe(pos, rootMoves); + + if (!RootInTB) { + // DTZ tables are missing; try to rank moves using WDL tables + dtz_available = false; + RootInTB = root_probe_wdl(pos, rootMoves); + } + } + + if (RootInTB) { + // Sort moves according to TB rank + std::stable_sort( + rootMoves.begin(), rootMoves.end(), + [](const RootMove& a, const RootMove& b) { return a.tbRank > b.tbRank; }); + + // Probe during search only if DTZ is not available and we are winning + if (dtz_available || rootMoves[0].tbScore <= VALUE_DRAW) Cardinality = 0; + } else { + // Clean up if root_probe() and root_probe_wdl() have failed + for (auto& m : rootMoves) m.tbRank = 0; + } } -} -} // namespace Stockfish +} // namespace Stockfish diff --git a/src/search.h b/src/search.h index c6dbffce0c7..38afe1b9e33 100644 --- a/src/search.h +++ b/src/search.h @@ -28,90 +28,88 @@ namespace Stockfish { -class Position; + class Position; -namespace Search { + namespace Search { -/// Stack struct keeps track of the information we need to remember from nodes -/// shallower and deeper in the tree during the search. Each search thread has -/// its own array of Stack objects, indexed by the current ply. + /// Stack struct keeps track of the information we need to remember from nodes + /// shallower and deeper in the tree during the search. Each search thread has + /// its own array of Stack objects, indexed by the current ply. -struct Stack { - Move* pv; - PieceToHistory* continuationHistory; - int ply; - Move currentMove; - Move excludedMove; - Move killers[2]; - Value staticEval; - int statScore; - int moveCount; - bool inCheck; - bool ttPv; - bool ttHit; - int doubleExtensions; - int cutoffCnt; -}; + struct Stack { + Move* pv; + PieceToHistory* continuationHistory; + int ply; + Move currentMove; + Move excludedMove; + Move killers[2]; + Value staticEval; + int statScore; + int moveCount; + bool inCheck; + bool ttPv; + bool ttHit; + int doubleExtensions; + int cutoffCnt; + }; -/// RootMove struct is used for moves at the root of the tree. For each root move -/// we store a score and a PV (really a refutation in the case of moves which -/// fail low). Score is normally set at -VALUE_INFINITE for all non-pv moves. + /// RootMove struct is used for moves at the root of the tree. For each root move + /// we store a score and a PV (really a refutation in the case of moves which + /// fail low). Score is normally set at -VALUE_INFINITE for all non-pv moves. -struct RootMove { + struct RootMove { - explicit RootMove(Move m) : pv(1, m) {} - bool extract_ponder_from_tt(Position& pos); - bool operator==(const Move& m) const { return pv[0] == m; } - bool operator<(const RootMove& m) const { // Sort in descending order - return m.score != score ? m.score < score - : m.previousScore < previousScore; - } + explicit RootMove(Move m) : pv(1, m) {} + bool extract_ponder_from_tt(Position& pos); + bool operator==(const Move& m) const { return pv[0] == m; } + bool operator<(const RootMove& m) const { // Sort in descending order + return m.score != score ? m.score < score : m.previousScore < previousScore; + } - Value score = -VALUE_INFINITE; - Value previousScore = -VALUE_INFINITE; - Value averageScore = -VALUE_INFINITE; - Value uciScore = -VALUE_INFINITE; - bool scoreLowerbound = false; - bool scoreUpperbound = false; - int selDepth = 0; - int tbRank = 0; - Value tbScore; - std::vector pv; -}; + Value score = -VALUE_INFINITE; + Value previousScore = -VALUE_INFINITE; + Value averageScore = -VALUE_INFINITE; + Value uciScore = -VALUE_INFINITE; + bool scoreLowerbound = false; + bool scoreUpperbound = false; + int selDepth = 0; + int tbRank = 0; + Value tbScore; + std::vector pv; + }; -using RootMoves = std::vector; + using RootMoves = std::vector; -/// LimitsType struct stores information sent by GUI about available time to -/// search the current move, maximum depth/time, or if we are in analysis mode. + /// LimitsType struct stores information sent by GUI about available time to + /// search the current move, maximum depth/time, or if we are in analysis mode. -struct LimitsType { + struct LimitsType { - LimitsType() { // Init explicitly due to broken value-initialization of non POD in MSVC - time[WHITE] = time[BLACK] = inc[WHITE] = inc[BLACK] = npmsec = movetime = TimePoint(0); - movestogo = depth = mate = perft = infinite = 0; - nodes = 0; - } + LimitsType() { // Init explicitly due to broken value-initialization of non POD in MSVC + time[WHITE] = time[BLACK] = inc[WHITE] = inc[BLACK] = npmsec = movetime = + TimePoint(0); + movestogo = depth = mate = perft = infinite = 0; + nodes = 0; + } - bool use_time_management() const { - return time[WHITE] || time[BLACK]; - } + bool use_time_management() const { return time[WHITE] || time[BLACK]; } - std::vector searchmoves; - TimePoint time[COLOR_NB], inc[COLOR_NB], npmsec, movetime, startTime; - int movestogo, depth, mate, perft, infinite; - int64_t nodes; -}; + std::vector searchmoves; + TimePoint time[COLOR_NB], inc[COLOR_NB], npmsec, movetime, startTime; + int movestogo, depth, mate, perft, infinite; + int64_t nodes; + }; -extern LimitsType Limits; + extern LimitsType Limits; -void init(); -void clear(); + void init(); + void clear(); -} // namespace Search + } // namespace Search -} // namespace Stockfish +} // namespace Stockfish -#endif // #ifndef SEARCH_H_INCLUDED +#endif // #ifndef SEARCH_H_INCLUDED diff --git a/src/syzygy/tbprobe.cpp b/src/syzygy/tbprobe.cpp index d1b32d242c9..5897da90a5f 100644 --- a/src/syzygy/tbprobe.cpp +++ b/src/syzygy/tbprobe.cpp @@ -45,15 +45,15 @@ #include "../uci.h" #ifndef _WIN32 -#include -#include -#include + #include + #include + #include #else -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -# define NOMINMAX // Disable macros min() and max() -#endif -#include + #define WIN32_LEAN_AND_MEAN + #ifndef NOMINMAX + #define NOMINMAX // Disable macros min() and max() + #endif + #include #endif using namespace Stockfish::Tablebases; @@ -62,1576 +62,1502 @@ int Stockfish::Tablebases::MaxCardinality; namespace Stockfish { -namespace { - -constexpr int TBPIECES = 7; // Max number of supported pieces -constexpr int MAX_DTZ = 1 << 18; // Max DTZ supported, large enough to deal with the syzygy TB limit. - -enum { BigEndian, LittleEndian }; -enum TBType { WDL, DTZ }; // Used as template parameter - -// Each table has a set of flags: all of them refer to DTZ tables, the last one to WDL tables -enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, Wide = 16, SingleValue = 128 }; - -inline WDLScore operator-(WDLScore d) { return WDLScore(-int(d)); } -inline Square operator^(Square s, int i) { return Square(int(s) ^ i); } - -constexpr std::string_view PieceToChar = " PNBRQK pnbrqk"; - -int MapPawns[SQUARE_NB]; -int MapB1H1H7[SQUARE_NB]; -int MapA1D1D4[SQUARE_NB]; -int MapKK[10][SQUARE_NB]; // [MapA1D1D4][SQUARE_NB] - -int Binomial[6][SQUARE_NB]; // [k][n] k elements from a set of n elements -int LeadPawnIdx[6][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB] -int LeadPawnsSize[6][4]; // [leadPawnsCnt][FILE_A..FILE_D] - -// Comparison function to sort leading pawns in ascending MapPawns[] order -bool pawns_comp(Square i, Square j) { return MapPawns[i] < MapPawns[j]; } -int off_A1H8(Square sq) { return int(rank_of(sq)) - file_of(sq); } - -constexpr Value WDL_to_value[] = { - -VALUE_MATE + MAX_PLY + 1, - VALUE_DRAW - 2, - VALUE_DRAW, - VALUE_DRAW + 2, - VALUE_MATE - MAX_PLY - 1 -}; - -template -inline void swap_endian(T& x) -{ - static_assert(std::is_unsigned::value, "Argument of swap_endian not unsigned"); - - uint8_t tmp, *c = (uint8_t*)&x; - for (int i = 0; i < Half; ++i) - tmp = c[i], c[i] = c[End - i], c[End - i] = tmp; -} -template<> inline void swap_endian(uint8_t&) {} - -template T number(void* addr) -{ - T v; - - if ((uintptr_t)addr & (alignof(T) - 1)) // Unaligned pointer (very rare) - std::memcpy(&v, addr, sizeof(T)); - else - v = *((T*)addr); - - if (LE != IsLittleEndian) - swap_endian(v); - return v; -} - -// DTZ tables don't store valid scores for moves that reset the rule50 counter -// like captures and pawn moves but we can easily recover the correct dtz of the -// previous move if we know the position's WDL score. -int dtz_before_zeroing(WDLScore wdl) { - return wdl == WDLWin ? 1 : - wdl == WDLCursedWin ? 101 : - wdl == WDLBlessedLoss ? -101 : - wdl == WDLLoss ? -1 : 0; -} - -// Return the sign of a number (-1, 0, 1) -template int sign_of(T val) { - return (T(0) < val) - (val < T(0)); -} - -// Numbers in little endian used by sparseIndex[] to point into blockLength[] -struct SparseEntry { - char block[4]; // Number of block - char offset[2]; // Offset within the block -}; - -static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes"); - -using Sym = uint16_t; // Huffman symbol - -struct LR { - enum Side { Left, Right }; - - uint8_t lr[3]; // The first 12 bits is the left-hand symbol, the second 12 - // bits is the right-hand symbol. If symbol has length 1, - // then the left-hand symbol is the stored value. - template - Sym get() { - return S == Left ? ((lr[1] & 0xF) << 8) | lr[0] : - S == Right ? (lr[2] << 4) | (lr[1] >> 4) : (assert(false), Sym(-1)); - } -}; + namespace { + + constexpr int TBPIECES = 7; // Max number of supported pieces + constexpr int MAX_DTZ = + 1 << 18; // Max DTZ supported, large enough to deal with the syzygy TB limit. + + enum { + BigEndian, + LittleEndian + }; + enum TBType { + WDL, + DTZ + }; // Used as template parameter + + // Each table has a set of flags: all of them refer to DTZ tables, the last one to WDL tables + enum TBFlag { + STM = 1, + Mapped = 2, + WinPlies = 4, + LossPlies = 8, + Wide = 16, + SingleValue = 128 + }; + + inline WDLScore operator-(WDLScore d) { return WDLScore(-int(d)); } + inline Square operator^(Square s, int i) { return Square(int(s) ^ i); } + + constexpr std::string_view PieceToChar = " PNBRQK pnbrqk"; + + int MapPawns[SQUARE_NB]; + int MapB1H1H7[SQUARE_NB]; + int MapA1D1D4[SQUARE_NB]; + int MapKK[10][SQUARE_NB]; // [MapA1D1D4][SQUARE_NB] + + int Binomial[6][SQUARE_NB]; // [k][n] k elements from a set of n elements + int LeadPawnIdx[6][SQUARE_NB]; // [leadPawnsCnt][SQUARE_NB] + int LeadPawnsSize[6][4]; // [leadPawnsCnt][FILE_A..FILE_D] + + // Comparison function to sort leading pawns in ascending MapPawns[] order + bool pawns_comp(Square i, Square j) { return MapPawns[i] < MapPawns[j]; } + int off_A1H8(Square sq) { return int(rank_of(sq)) - file_of(sq); } + + constexpr Value WDL_to_value[] = {-VALUE_MATE + MAX_PLY + 1, VALUE_DRAW - 2, VALUE_DRAW, + VALUE_DRAW + 2, VALUE_MATE - MAX_PLY - 1}; + + template + inline void swap_endian(T& x) { + static_assert(std::is_unsigned::value, "Argument of swap_endian not unsigned"); + + uint8_t tmp, *c = (uint8_t*) &x; + for (int i = 0; i < Half; ++i) tmp = c[i], c[i] = c[End - i], c[End - i] = tmp; + } + template<> inline void swap_endian(uint8_t&) {} -static_assert(sizeof(LR) == 3, "LR tree entry must be 3 bytes"); + template T number(void* addr) { + T v; -// Tablebases data layout is structured as following: -// -// TBFile: memory maps/unmaps the physical .rtbw and .rtbz files -// TBTable: one object for each file with corresponding indexing information -// TBTables: has ownership of TBTable objects, keeping a list and a hash + if ((uintptr_t) addr & (alignof(T) - 1)) // Unaligned pointer (very rare) + std::memcpy(&v, addr, sizeof(T)); + else + v = *((T*) addr); -// class TBFile memory maps/unmaps the single .rtbw and .rtbz files. Files are -// memory mapped for best performance. Files are mapped at first access: at init -// time only existence of the file is checked. -class TBFile : public std::ifstream { + if (LE != IsLittleEndian) swap_endian(v); + return v; + } - std::string fname; + // DTZ tables don't store valid scores for moves that reset the rule50 counter + // like captures and pawn moves but we can easily recover the correct dtz of the + // previous move if we know the position's WDL score. + int dtz_before_zeroing(WDLScore wdl) { + return wdl == WDLWin ? 1 : + wdl == WDLCursedWin ? 101 : + wdl == WDLBlessedLoss ? -101 : + wdl == WDLLoss ? -1 : + 0; + } -public: - // Look for and open the file among the Paths directories where the .rtbw - // and .rtbz files can be found. Multiple directories are separated by ";" - // on Windows and by ":" on Unix-based operating systems. - // - // Example: - // C:\tb\wdl345;C:\tb\wdl6;D:\tb\dtz345;D:\tb\dtz6 - static std::string Paths; + // Return the sign of a number (-1, 0, 1) + template int sign_of(T val) { return (T(0) < val) - (val < T(0)); } + + // Numbers in little endian used by sparseIndex[] to point into blockLength[] + struct SparseEntry { + char block[4]; // Number of block + char offset[2]; // Offset within the block + }; + + static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes"); + + using Sym = uint16_t; // Huffman symbol + + struct LR { + enum Side { + Left, + Right + }; + + uint8_t lr[3]; // The first 12 bits is the left-hand symbol, the second 12 + // bits is the right-hand symbol. If symbol has length 1, + // then the left-hand symbol is the stored value. + template Sym get() { + return S == Left ? ((lr[1] & 0xF) << 8) | lr[0] : + S == Right ? (lr[2] << 4) | (lr[1] >> 4) : + (assert(false), Sym(-1)); + } + }; - TBFile(const std::string& f) { + static_assert(sizeof(LR) == 3, "LR tree entry must be 3 bytes"); + + // Tablebases data layout is structured as following: + // + // TBFile: memory maps/unmaps the physical .rtbw and .rtbz files + // TBTable: one object for each file with corresponding indexing information + // TBTables: has ownership of TBTable objects, keeping a list and a hash + + // class TBFile memory maps/unmaps the single .rtbw and .rtbz files. Files are + // memory mapped for best performance. Files are mapped at first access: at init + // time only existence of the file is checked. + class TBFile: public std::ifstream { + + std::string fname; + + public: + // Look for and open the file among the Paths directories where the .rtbw + // and .rtbz files can be found. Multiple directories are separated by ";" + // on Windows and by ":" on Unix-based operating systems. + // + // Example: + // C:\tb\wdl345;C:\tb\wdl6;D:\tb\dtz345;D:\tb\dtz6 + static std::string Paths; + + TBFile(const std::string& f) { #ifndef _WIN32 - constexpr char SepChar = ':'; + constexpr char SepChar = ':'; #else - constexpr char SepChar = ';'; + constexpr char SepChar = ';'; #endif - std::stringstream ss(Paths); - std::string path; - - while (std::getline(ss, path, SepChar)) - { - fname = path + "/" + f; - std::ifstream::open(fname); - if (is_open()) - return; - } - } + std::stringstream ss(Paths); + std::string path; + + while (std::getline(ss, path, SepChar)) { + fname = path + "/" + f; + std::ifstream::open(fname); + if (is_open()) return; + } + } - // Memory map the file and check it. - uint8_t* map(void** baseAddress, uint64_t* mapping, TBType type) { - if (is_open()) - close(); // Need to re-open to get native file descriptor + // Memory map the file and check it. + uint8_t* map(void** baseAddress, uint64_t* mapping, TBType type) { + if (is_open()) close(); // Need to re-open to get native file descriptor #ifndef _WIN32 - struct stat statbuf; - int fd = ::open(fname.c_str(), O_RDONLY); + struct stat statbuf; + int fd = ::open(fname.c_str(), O_RDONLY); - if (fd == -1) - return *baseAddress = nullptr, nullptr; + if (fd == -1) return *baseAddress = nullptr, nullptr; - fstat(fd, &statbuf); + fstat(fd, &statbuf); - if (statbuf.st_size % 64 != 16) - { - std::cerr << "Corrupt tablebase file " << fname << std::endl; - exit(EXIT_FAILURE); - } + if (statbuf.st_size % 64 != 16) { + std::cerr << "Corrupt tablebase file " << fname << std::endl; + exit(EXIT_FAILURE); + } - *mapping = statbuf.st_size; - *baseAddress = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_SHARED, fd, 0); -#if defined(MADV_RANDOM) - madvise(*baseAddress, statbuf.st_size, MADV_RANDOM); -#endif - ::close(fd); + *mapping = statbuf.st_size; + *baseAddress = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_SHARED, fd, 0); + #if defined(MADV_RANDOM) + madvise(*baseAddress, statbuf.st_size, MADV_RANDOM); + #endif + ::close(fd); - if (*baseAddress == MAP_FAILED) - { - std::cerr << "Could not mmap() " << fname << std::endl; - exit(EXIT_FAILURE); - } + if (*baseAddress == MAP_FAILED) { + std::cerr << "Could not mmap() " << fname << std::endl; + exit(EXIT_FAILURE); + } #else - // Note FILE_FLAG_RANDOM_ACCESS is only a hint to Windows and as such may get ignored. - HANDLE fd = CreateFileA(fname.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, - OPEN_EXISTING, FILE_FLAG_RANDOM_ACCESS, nullptr); + // Note FILE_FLAG_RANDOM_ACCESS is only a hint to Windows and as such may get ignored. + HANDLE fd = CreateFileA(fname.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, + OPEN_EXISTING, FILE_FLAG_RANDOM_ACCESS, nullptr); - if (fd == INVALID_HANDLE_VALUE) - return *baseAddress = nullptr, nullptr; + if (fd == INVALID_HANDLE_VALUE) return *baseAddress = nullptr, nullptr; - DWORD size_high; - DWORD size_low = GetFileSize(fd, &size_high); + DWORD size_high; + DWORD size_low = GetFileSize(fd, &size_high); - if (size_low % 64 != 16) - { - std::cerr << "Corrupt tablebase file " << fname << std::endl; - exit(EXIT_FAILURE); - } + if (size_low % 64 != 16) { + std::cerr << "Corrupt tablebase file " << fname << std::endl; + exit(EXIT_FAILURE); + } - HANDLE mmap = CreateFileMapping(fd, nullptr, PAGE_READONLY, size_high, size_low, nullptr); - CloseHandle(fd); + HANDLE mmap = + CreateFileMapping(fd, nullptr, PAGE_READONLY, size_high, size_low, nullptr); + CloseHandle(fd); - if (!mmap) - { - std::cerr << "CreateFileMapping() failed" << std::endl; - exit(EXIT_FAILURE); - } + if (!mmap) { + std::cerr << "CreateFileMapping() failed" << std::endl; + exit(EXIT_FAILURE); + } - *mapping = (uint64_t)mmap; - *baseAddress = MapViewOfFile(mmap, FILE_MAP_READ, 0, 0, 0); + *mapping = (uint64_t) mmap; + *baseAddress = MapViewOfFile(mmap, FILE_MAP_READ, 0, 0, 0); - if (!*baseAddress) - { - std::cerr << "MapViewOfFile() failed, name = " << fname - << ", error = " << GetLastError() << std::endl; - exit(EXIT_FAILURE); - } + if (!*baseAddress) { + std::cerr << "MapViewOfFile() failed, name = " << fname + << ", error = " << GetLastError() << std::endl; + exit(EXIT_FAILURE); + } #endif - uint8_t* data = (uint8_t*)*baseAddress; + uint8_t* data = (uint8_t*) *baseAddress; - constexpr uint8_t Magics[][4] = { { 0xD7, 0x66, 0x0C, 0xA5 }, - { 0x71, 0xE8, 0x23, 0x5D } }; + constexpr uint8_t Magics[][4] = {{0xD7, 0x66, 0x0C, 0xA5}, + {0x71, 0xE8, 0x23, 0x5D}}; - if (memcmp(data, Magics[type == WDL], 4)) - { - std::cerr << "Corrupted table in file " << fname << std::endl; - unmap(*baseAddress, *mapping); - return *baseAddress = nullptr, nullptr; - } + if (memcmp(data, Magics[type == WDL], 4)) { + std::cerr << "Corrupted table in file " << fname << std::endl; + unmap(*baseAddress, *mapping); + return *baseAddress = nullptr, nullptr; + } - return data + 4; // Skip Magics's header - } + return data + 4; // Skip Magics's header + } - static void unmap(void* baseAddress, uint64_t mapping) { + static void unmap(void* baseAddress, uint64_t mapping) { #ifndef _WIN32 - munmap(baseAddress, mapping); + munmap(baseAddress, mapping); #else - UnmapViewOfFile(baseAddress); - CloseHandle((HANDLE)mapping); + UnmapViewOfFile(baseAddress); + CloseHandle((HANDLE) mapping); #endif - } -}; - -std::string TBFile::Paths; - -// struct PairsData contains low level indexing information to access TB data. -// There are 8, 4 or 2 PairsData records for each TBTable, according to type of -// table and if positions have pawns or not. It is populated at first access. -struct PairsData { - uint8_t flags; // Table flags, see enum TBFlag - uint8_t maxSymLen; // Maximum length in bits of the Huffman symbols - uint8_t minSymLen; // Minimum length in bits of the Huffman symbols - uint32_t blocksNum; // Number of blocks in the TB file - size_t sizeofBlock; // Block size in bytes - size_t span; // About every span values there is a SparseIndex[] entry - Sym* lowestSym; // lowestSym[l] is the symbol of length l with the lowest value - LR* btree; // btree[sym] stores the left and right symbols that expand sym - uint16_t* blockLength; // Number of stored positions (minus one) for each block: 1..65536 - uint32_t blockLengthSize; // Size of blockLength[] table: padded so it's bigger than blocksNum - SparseEntry* sparseIndex; // Partial indices into blockLength[] - size_t sparseIndexSize; // Size of SparseIndex[] table - uint8_t* data; // Start of Huffman compressed data - std::vector base64; // base64[l - min_sym_len] is the 64bit-padded lowest symbol of length l - std::vector symlen; // Number of values (-1) represented by a given Huffman symbol: 1..256 - Piece pieces[TBPIECES]; // Position pieces: the order of pieces defines the groups - uint64_t groupIdx[TBPIECES+1]; // Start index used for the encoding of the group's pieces - int groupLen[TBPIECES+1]; // Number of pieces in a given group: KRKN -> (3, 1) - uint16_t map_idx[4]; // WDLWin, WDLLoss, WDLCursedWin, WDLBlessedLoss (used in DTZ) -}; - -// struct TBTable contains indexing information to access the corresponding TBFile. -// There are 2 types of TBTable, corresponding to a WDL or a DTZ file. TBTable -// is populated at init time but the nested PairsData records are populated at -// first access, when the corresponding file is memory mapped. -template -struct TBTable { - using Ret = typename std::conditional::type; - - static constexpr int Sides = Type == WDL ? 2 : 1; - - std::atomic_bool ready; - void* baseAddress; - uint8_t* map; - uint64_t mapping; - Key key; - Key key2; - int pieceCount; - bool hasPawns; - bool hasUniquePieces; - uint8_t pawnCount[2]; // [Lead color / other color] - PairsData items[Sides][4]; // [wtm / btm][FILE_A..FILE_D or 0] - - PairsData* get(int stm, int f) { - return &items[stm % Sides][hasPawns ? f : 0]; - } + } + }; + + std::string TBFile::Paths; + + // struct PairsData contains low level indexing information to access TB data. + // There are 8, 4 or 2 PairsData records for each TBTable, according to type of + // table and if positions have pawns or not. It is populated at first access. + struct PairsData { + uint8_t flags; // Table flags, see enum TBFlag + uint8_t maxSymLen; // Maximum length in bits of the Huffman symbols + uint8_t minSymLen; // Minimum length in bits of the Huffman symbols + uint32_t blocksNum; // Number of blocks in the TB file + size_t sizeofBlock; // Block size in bytes + size_t span; // About every span values there is a SparseIndex[] entry + Sym* lowestSym; // lowestSym[l] is the symbol of length l with the lowest value + LR* btree; // btree[sym] stores the left and right symbols that expand sym + uint16_t* + blockLength; // Number of stored positions (minus one) for each block: 1..65536 + uint32_t + blockLengthSize; // Size of blockLength[] table: padded so it's bigger than blocksNum + SparseEntry* sparseIndex; // Partial indices into blockLength[] + size_t sparseIndexSize; // Size of SparseIndex[] table + uint8_t* data; // Start of Huffman compressed data + std::vector + base64; // base64[l - min_sym_len] is the 64bit-padded lowest symbol of length l + std::vector + symlen; // Number of values (-1) represented by a given Huffman symbol: 1..256 + Piece pieces[TBPIECES]; // Position pieces: the order of pieces defines the groups + uint64_t + groupIdx[TBPIECES + 1]; // Start index used for the encoding of the group's pieces + int groupLen[TBPIECES + 1]; // Number of pieces in a given group: KRKN -> (3, 1) + uint16_t map_idx[4]; // WDLWin, WDLLoss, WDLCursedWin, WDLBlessedLoss (used in DTZ) + }; + + // struct TBTable contains indexing information to access the corresponding TBFile. + // There are 2 types of TBTable, corresponding to a WDL or a DTZ file. TBTable + // is populated at init time but the nested PairsData records are populated at + // first access, when the corresponding file is memory mapped. + template struct TBTable { + using Ret = typename std::conditional::type; + + static constexpr int Sides = Type == WDL ? 2 : 1; + + std::atomic_bool ready; + void* baseAddress; + uint8_t* map; + uint64_t mapping; + Key key; + Key key2; + int pieceCount; + bool hasPawns; + bool hasUniquePieces; + uint8_t pawnCount[2]; // [Lead color / other color] + PairsData items[Sides][4]; // [wtm / btm][FILE_A..FILE_D or 0] + + PairsData* get(int stm, int f) { return &items[stm % Sides][hasPawns ? f : 0]; } + + TBTable() : ready(false), baseAddress(nullptr) {} + explicit TBTable(const std::string& code); + explicit TBTable(const TBTable& wdl); + + ~TBTable() { + if (baseAddress) TBFile::unmap(baseAddress, mapping); + } + }; - TBTable() : ready(false), baseAddress(nullptr) {} - explicit TBTable(const std::string& code); - explicit TBTable(const TBTable& wdl); + template<> TBTable::TBTable(const std::string& code) : TBTable() { - ~TBTable() { - if (baseAddress) - TBFile::unmap(baseAddress, mapping); - } -}; - -template<> -TBTable::TBTable(const std::string& code) : TBTable() { - - StateInfo st; - Position pos; - - key = pos.set(code, WHITE, &st).material_key(); - pieceCount = pos.count(); - hasPawns = pos.pieces(PAWN); - - hasUniquePieces = false; - for (Color c : { WHITE, BLACK }) - for (PieceType pt = PAWN; pt < KING; ++pt) - if (popcount(pos.pieces(c, pt)) == 1) - hasUniquePieces = true; - - // Set the leading color. In case both sides have pawns the leading color - // is the side with less pawns because this leads to better compression. - bool c = !pos.count(BLACK) - || ( pos.count(WHITE) - && pos.count(BLACK) >= pos.count(WHITE)); - - pawnCount[0] = pos.count(c ? WHITE : BLACK); - pawnCount[1] = pos.count(c ? BLACK : WHITE); - - key2 = pos.set(code, BLACK, &st).material_key(); -} - -template<> -TBTable::TBTable(const TBTable& wdl) : TBTable() { - - // Use the corresponding WDL table to avoid recalculating all from scratch - key = wdl.key; - key2 = wdl.key2; - pieceCount = wdl.pieceCount; - hasPawns = wdl.hasPawns; - hasUniquePieces = wdl.hasUniquePieces; - pawnCount[0] = wdl.pawnCount[0]; - pawnCount[1] = wdl.pawnCount[1]; -} - -// class TBTables creates and keeps ownership of the TBTable objects, one for -// each TB file found. It supports a fast, hash based, table lookup. Populated -// at init time, accessed at probe time. -class TBTables { - - struct Entry - { - Key key; - TBTable* wdl; - TBTable* dtz; - - template - TBTable* get() const { - return (TBTable*)(Type == WDL ? (void*)wdl : (void*)dtz); + StateInfo st; + Position pos; + + key = pos.set(code, WHITE, &st).material_key(); + pieceCount = pos.count(); + hasPawns = pos.pieces(PAWN); + + hasUniquePieces = false; + for (Color c : {WHITE, BLACK}) + for (PieceType pt = PAWN; pt < KING; ++pt) + if (popcount(pos.pieces(c, pt)) == 1) hasUniquePieces = true; + + // Set the leading color. In case both sides have pawns the leading color + // is the side with less pawns because this leads to better compression. + bool c = !pos.count(BLACK) || + (pos.count(WHITE) && pos.count(BLACK) >= pos.count(WHITE)); + + pawnCount[0] = pos.count(c ? WHITE : BLACK); + pawnCount[1] = pos.count(c ? BLACK : WHITE); + + key2 = pos.set(code, BLACK, &st).material_key(); } - }; - static constexpr int Size = 1 << 12; // 4K table, indexed by key's 12 lsb - static constexpr int Overflow = 1; // Number of elements allowed to map to the last bucket + template<> TBTable::TBTable(const TBTable& wdl) : TBTable() { - Entry hashTable[Size + Overflow]; + // Use the corresponding WDL table to avoid recalculating all from scratch + key = wdl.key; + key2 = wdl.key2; + pieceCount = wdl.pieceCount; + hasPawns = wdl.hasPawns; + hasUniquePieces = wdl.hasUniquePieces; + pawnCount[0] = wdl.pawnCount[0]; + pawnCount[1] = wdl.pawnCount[1]; + } - std::deque> wdlTable; - std::deque> dtzTable; + // class TBTables creates and keeps ownership of the TBTable objects, one for + // each TB file found. It supports a fast, hash based, table lookup. Populated + // at init time, accessed at probe time. + class TBTables { - void insert(Key key, TBTable* wdl, TBTable* dtz) { - uint32_t homeBucket = (uint32_t)key & (Size - 1); - Entry entry{ key, wdl, dtz }; + struct Entry { + Key key; + TBTable* wdl; + TBTable* dtz; - // Ensure last element is empty to avoid overflow when looking up - for (uint32_t bucket = homeBucket; bucket < Size + Overflow - 1; ++bucket) { - Key otherKey = hashTable[bucket].key; - if (otherKey == key || !hashTable[bucket].get()) { - hashTable[bucket] = entry; - return; + template TBTable* get() const { + return (TBTable*) (Type == WDL ? (void*) wdl : (void*) dtz); + } + }; + + static constexpr int Size = 1 << 12; // 4K table, indexed by key's 12 lsb + static constexpr int Overflow = + 1; // Number of elements allowed to map to the last bucket + + Entry hashTable[Size + Overflow]; + + std::deque> wdlTable; + std::deque> dtzTable; + + void insert(Key key, TBTable* wdl, TBTable* dtz) { + uint32_t homeBucket = (uint32_t) key & (Size - 1); + Entry entry{key, wdl, dtz}; + + // Ensure last element is empty to avoid overflow when looking up + for (uint32_t bucket = homeBucket; bucket < Size + Overflow - 1; ++bucket) { + Key otherKey = hashTable[bucket].key; + if (otherKey == key || !hashTable[bucket].get()) { + hashTable[bucket] = entry; + return; + } + + // Robin Hood hashing: If we've probed for longer than this element, + // insert here and search for a new spot for the other element instead. + uint32_t otherHomeBucket = (uint32_t) otherKey & (Size - 1); + if (otherHomeBucket > homeBucket) { + std::swap(entry, hashTable[bucket]); + key = otherKey; + homeBucket = otherHomeBucket; + } + } + std::cerr << "TB hash table size too low!" << std::endl; + exit(EXIT_FAILURE); } - // Robin Hood hashing: If we've probed for longer than this element, - // insert here and search for a new spot for the other element instead. - uint32_t otherHomeBucket = (uint32_t)otherKey & (Size - 1); - if (otherHomeBucket > homeBucket) { - std::swap(entry, hashTable[bucket]); - key = otherKey; - homeBucket = otherHomeBucket; + public: + template TBTable* get(Key key) { + for (const Entry* entry = &hashTable[(uint32_t) key & (Size - 1)];; ++entry) { + if (entry->key == key || !entry->get()) return entry->get(); + } } - } - std::cerr << "TB hash table size too low!" << std::endl; - exit(EXIT_FAILURE); - } -public: - template - TBTable* get(Key key) { - for (const Entry* entry = &hashTable[(uint32_t)key & (Size - 1)]; ; ++entry) { - if (entry->key == key || !entry->get()) - return entry->get(); - } - } + void clear() { + memset(hashTable, 0, sizeof(hashTable)); + wdlTable.clear(); + dtzTable.clear(); + } + size_t size() const { return wdlTable.size(); } + void add(const std::vector& pieces); + }; - void clear() { - memset(hashTable, 0, sizeof(hashTable)); - wdlTable.clear(); - dtzTable.clear(); - } - size_t size() const { return wdlTable.size(); } - void add(const std::vector& pieces); -}; - -TBTables TBTables; - -// If the corresponding file exists two new objects TBTable and TBTable -// are created and added to the lists and hash table. Called at init time. -void TBTables::add(const std::vector& pieces) { - - std::string code; - - for (PieceType pt : pieces) - code += PieceToChar[pt]; - - TBFile file(code.insert(code.find('K', 1), "v") + ".rtbw"); // KRK -> KRvK - - if (!file.is_open()) // Only WDL file is checked - return; - - file.close(); - - MaxCardinality = std::max((int)pieces.size(), MaxCardinality); - - wdlTable.emplace_back(code); - dtzTable.emplace_back(wdlTable.back()); - - // Insert into the hash keys for both colors: KRvK with KR white and black - insert(wdlTable.back().key , &wdlTable.back(), &dtzTable.back()); - insert(wdlTable.back().key2, &wdlTable.back(), &dtzTable.back()); -} - -// TB tables are compressed with canonical Huffman code. The compressed data is divided into -// blocks of size d->sizeofBlock, and each block stores a variable number of symbols. -// Each symbol represents either a WDL or a (remapped) DTZ value, or a pair of other symbols -// (recursively). If you keep expanding the symbols in a block, you end up with up to 65536 -// WDL or DTZ values. Each symbol represents up to 256 values and will correspond after -// Huffman coding to at least 1 bit. So a block of 32 bytes corresponds to at most -// 32 x 8 x 256 = 65536 values. This maximum is only reached for tables that consist mostly -// of draws or mostly of wins, but such tables are actually quite common. In principle, the -// blocks in WDL tables are 64 bytes long (and will be aligned on cache lines). But for -// mostly-draw or mostly-win tables this can leave many 64-byte blocks only half-filled, so -// in such cases blocks are 32 bytes long. The blocks of DTZ tables are up to 1024 bytes long. -// The generator picks the size that leads to the smallest table. The "book" of symbols and -// Huffman codes is the same for all blocks in the table. A non-symmetric pawnless TB file -// will have one table for wtm and one for btm, a TB file with pawns will have tables per -// file a,b,c,d also in this case one set for wtm and one for btm. -int decompress_pairs(PairsData* d, uint64_t idx) { - - // Special case where all table positions store the same value - if (d->flags & TBFlag::SingleValue) - return d->minSymLen; - - // First we need to locate the right block that stores the value at index "idx". - // Because each block n stores blockLength[n] + 1 values, the index i of the block - // that contains the value at position idx is: - // - // for (i = -1, sum = 0; sum <= idx; i++) - // sum += blockLength[i + 1] + 1; - // - // This can be slow, so we use SparseIndex[] populated with a set of SparseEntry that - // point to known indices into blockLength[]. Namely SparseIndex[k] is a SparseEntry - // that stores the blockLength[] index and the offset within that block of the value - // with index I(k), where: - // - // I(k) = k * d->span + d->span / 2 (1) + TBTables TBTables; - // First step is to get the 'k' of the I(k) nearest to our idx, using definition (1) - uint32_t k = uint32_t(idx / d->span); + // If the corresponding file exists two new objects TBTable and TBTable + // are created and added to the lists and hash table. Called at init time. + void TBTables::add(const std::vector& pieces) { - // Then we read the corresponding SparseIndex[] entry - uint32_t block = number(&d->sparseIndex[k].block); - int offset = number(&d->sparseIndex[k].offset); + std::string code; - // Now compute the difference idx - I(k). From definition of k we know that - // - // idx = k * d->span + idx % d->span (2) - // - // So from (1) and (2) we can compute idx - I(K): - int diff = idx % d->span - d->span / 2; - - // Sum the above to offset to find the offset corresponding to our idx - offset += diff; - - // Move to previous/next block, until we reach the correct block that contains idx, - // that is when 0 <= offset <= d->blockLength[block] - while (offset < 0) - offset += d->blockLength[--block] + 1; - - while (offset > d->blockLength[block]) - offset -= d->blockLength[block++] + 1; - - // Finally, we find the start address of our block of canonical Huffman symbols - uint32_t* ptr = (uint32_t*)(d->data + ((uint64_t)block * d->sizeofBlock)); - - // Read the first 64 bits in our block, this is a (truncated) sequence of - // unknown number of symbols of unknown length but we know the first one - // is at the beginning of this 64 bits sequence. - uint64_t buf64 = number(ptr); ptr += 2; - int buf64Size = 64; - Sym sym; - - while (true) - { - int len = 0; // This is the symbol length - d->min_sym_len - - // Now get the symbol length. For any symbol s64 of length l right-padded - // to 64 bits we know that d->base64[l-1] >= s64 >= d->base64[l] so we - // can find the symbol length iterating through base64[]. - while (buf64 < d->base64[len]) - ++len; - - // All the symbols of a given length are consecutive integers (numerical - // sequence property), so we can compute the offset of our symbol of - // length len, stored at the beginning of buf64. - sym = Sym((buf64 - d->base64[len]) >> (64 - len - d->minSymLen)); - - // Now add the value of the lowest symbol of length len to get our symbol - sym += number(&d->lowestSym[len]); - - // If our offset is within the number of values represented by symbol sym - // we are done... - if (offset < d->symlen[sym] + 1) - break; - - // ...otherwise update the offset and continue to iterate - offset -= d->symlen[sym] + 1; - len += d->minSymLen; // Get the real length - buf64 <<= len; // Consume the just processed symbol - buf64Size -= len; - - if (buf64Size <= 32) { // Refill the buffer - buf64Size += 32; - buf64 |= (uint64_t)number(ptr++) << (64 - buf64Size); - } - } + for (PieceType pt : pieces) code += PieceToChar[pt]; + + TBFile file(code.insert(code.find('K', 1), "v") + ".rtbw"); // KRK -> KRvK - // Ok, now we have our symbol that expands into d->symlen[sym] + 1 symbols. - // We binary-search for our value recursively expanding into the left and - // right child symbols until we reach a leaf node where symlen[sym] + 1 == 1 - // that will store the value we need. - while (d->symlen[sym]) - { - Sym left = d->btree[sym].get(); - - // If a symbol contains 36 sub-symbols (d->symlen[sym] + 1 = 36) and - // expands in a pair (d->symlen[left] = 23, d->symlen[right] = 11), then - // we know that, for instance the ten-th value (offset = 10) will be on - // the left side because in Recursive Pairing child symbols are adjacent. - if (offset < d->symlen[left] + 1) - sym = left; - else { - offset -= d->symlen[left] + 1; - sym = d->btree[sym].get(); + if (!file.is_open()) // Only WDL file is checked + return; + + file.close(); + + MaxCardinality = std::max((int) pieces.size(), MaxCardinality); + + wdlTable.emplace_back(code); + dtzTable.emplace_back(wdlTable.back()); + + // Insert into the hash keys for both colors: KRvK with KR white and black + insert(wdlTable.back().key, &wdlTable.back(), &dtzTable.back()); + insert(wdlTable.back().key2, &wdlTable.back(), &dtzTable.back()); } - } - return d->btree[sym].get(); -} + // TB tables are compressed with canonical Huffman code. The compressed data is divided into + // blocks of size d->sizeofBlock, and each block stores a variable number of symbols. + // Each symbol represents either a WDL or a (remapped) DTZ value, or a pair of other symbols + // (recursively). If you keep expanding the symbols in a block, you end up with up to 65536 + // WDL or DTZ values. Each symbol represents up to 256 values and will correspond after + // Huffman coding to at least 1 bit. So a block of 32 bytes corresponds to at most + // 32 x 8 x 256 = 65536 values. This maximum is only reached for tables that consist mostly + // of draws or mostly of wins, but such tables are actually quite common. In principle, the + // blocks in WDL tables are 64 bytes long (and will be aligned on cache lines). But for + // mostly-draw or mostly-win tables this can leave many 64-byte blocks only half-filled, so + // in such cases blocks are 32 bytes long. The blocks of DTZ tables are up to 1024 bytes long. + // The generator picks the size that leads to the smallest table. The "book" of symbols and + // Huffman codes is the same for all blocks in the table. A non-symmetric pawnless TB file + // will have one table for wtm and one for btm, a TB file with pawns will have tables per + // file a,b,c,d also in this case one set for wtm and one for btm. + int decompress_pairs(PairsData* d, uint64_t idx) { + + // Special case where all table positions store the same value + if (d->flags & TBFlag::SingleValue) return d->minSymLen; + + // First we need to locate the right block that stores the value at index "idx". + // Because each block n stores blockLength[n] + 1 values, the index i of the block + // that contains the value at position idx is: + // + // for (i = -1, sum = 0; sum <= idx; i++) + // sum += blockLength[i + 1] + 1; + // + // This can be slow, so we use SparseIndex[] populated with a set of SparseEntry that + // point to known indices into blockLength[]. Namely SparseIndex[k] is a SparseEntry + // that stores the blockLength[] index and the offset within that block of the value + // with index I(k), where: + // + // I(k) = k * d->span + d->span / 2 (1) + + // First step is to get the 'k' of the I(k) nearest to our idx, using definition (1) + uint32_t k = uint32_t(idx / d->span); + + // Then we read the corresponding SparseIndex[] entry + uint32_t block = number(&d->sparseIndex[k].block); + int offset = number(&d->sparseIndex[k].offset); + + // Now compute the difference idx - I(k). From definition of k we know that + // + // idx = k * d->span + idx % d->span (2) + // + // So from (1) and (2) we can compute idx - I(K): + int diff = idx % d->span - d->span / 2; + + // Sum the above to offset to find the offset corresponding to our idx + offset += diff; + + // Move to previous/next block, until we reach the correct block that contains idx, + // that is when 0 <= offset <= d->blockLength[block] + while (offset < 0) offset += d->blockLength[--block] + 1; + + while (offset > d->blockLength[block]) offset -= d->blockLength[block++] + 1; + + // Finally, we find the start address of our block of canonical Huffman symbols + uint32_t* ptr = (uint32_t*) (d->data + ((uint64_t) block * d->sizeofBlock)); + + // Read the first 64 bits in our block, this is a (truncated) sequence of + // unknown number of symbols of unknown length but we know the first one + // is at the beginning of this 64 bits sequence. + uint64_t buf64 = number(ptr); + ptr += 2; + int buf64Size = 64; + Sym sym; + + while (true) { + int len = 0; // This is the symbol length - d->min_sym_len + + // Now get the symbol length. For any symbol s64 of length l right-padded + // to 64 bits we know that d->base64[l-1] >= s64 >= d->base64[l] so we + // can find the symbol length iterating through base64[]. + while (buf64 < d->base64[len]) ++len; + + // All the symbols of a given length are consecutive integers (numerical + // sequence property), so we can compute the offset of our symbol of + // length len, stored at the beginning of buf64. + sym = Sym((buf64 - d->base64[len]) >> (64 - len - d->minSymLen)); + + // Now add the value of the lowest symbol of length len to get our symbol + sym += number(&d->lowestSym[len]); + + // If our offset is within the number of values represented by symbol sym + // we are done... + if (offset < d->symlen[sym] + 1) break; + + // ...otherwise update the offset and continue to iterate + offset -= d->symlen[sym] + 1; + len += d->minSymLen; // Get the real length + buf64 <<= len; // Consume the just processed symbol + buf64Size -= len; + + if (buf64Size <= 32) { // Refill the buffer + buf64Size += 32; + buf64 |= (uint64_t) number(ptr++) << (64 - buf64Size); + } + } -bool check_dtz_stm(TBTable*, int, File) { return true; } + // Ok, now we have our symbol that expands into d->symlen[sym] + 1 symbols. + // We binary-search for our value recursively expanding into the left and + // right child symbols until we reach a leaf node where symlen[sym] + 1 == 1 + // that will store the value we need. + while (d->symlen[sym]) { + Sym left = d->btree[sym].get(); + + // If a symbol contains 36 sub-symbols (d->symlen[sym] + 1 = 36) and + // expands in a pair (d->symlen[left] = 23, d->symlen[right] = 11), then + // we know that, for instance the ten-th value (offset = 10) will be on + // the left side because in Recursive Pairing child symbols are adjacent. + if (offset < d->symlen[left] + 1) + sym = left; + else { + offset -= d->symlen[left] + 1; + sym = d->btree[sym].get(); + } + } -bool check_dtz_stm(TBTable* entry, int stm, File f) { + return d->btree[sym].get(); + } - auto flags = entry->get(stm, f)->flags; - return (flags & TBFlag::STM) == stm - || ((entry->key == entry->key2) && !entry->hasPawns); -} + bool check_dtz_stm(TBTable*, int, File) { return true; } -// DTZ scores are sorted by frequency of occurrence and then assigned the -// values 0, 1, 2, ... in order of decreasing frequency. This is done for each -// of the four WDLScore values. The mapping information necessary to reconstruct -// the original values is stored in the TB file and read during map[] init. -WDLScore map_score(TBTable*, File, int value, WDLScore) { return WDLScore(value - 2); } + bool check_dtz_stm(TBTable* entry, int stm, File f) { -int map_score(TBTable* entry, File f, int value, WDLScore wdl) { + auto flags = entry->get(stm, f)->flags; + return (flags & TBFlag::STM) == stm || + ((entry->key == entry->key2) && !entry->hasPawns); + } - constexpr int WDLMap[] = { 1, 3, 0, 2, 0 }; + // DTZ scores are sorted by frequency of occurrence and then assigned the + // values 0, 1, 2, ... in order of decreasing frequency. This is done for each + // of the four WDLScore values. The mapping information necessary to reconstruct + // the original values is stored in the TB file and read during map[] init. + WDLScore map_score(TBTable*, File, int value, WDLScore) { return WDLScore(value - 2); } - auto flags = entry->get(0, f)->flags; + int map_score(TBTable* entry, File f, int value, WDLScore wdl) { - uint8_t* map = entry->map; - uint16_t* idx = entry->get(0, f)->map_idx; - if (flags & TBFlag::Mapped) { - if (flags & TBFlag::Wide) - value = ((uint16_t *)map)[idx[WDLMap[wdl + 2]] + value]; - else - value = map[idx[WDLMap[wdl + 2]] + value]; - } + constexpr int WDLMap[] = {1, 3, 0, 2, 0}; - // DTZ tables store distance to zero in number of moves or plies. We - // want to return plies, so we have convert to plies when needed. - if ( (wdl == WDLWin && !(flags & TBFlag::WinPlies)) - || (wdl == WDLLoss && !(flags & TBFlag::LossPlies)) - || wdl == WDLCursedWin - || wdl == WDLBlessedLoss) - value *= 2; - - return value + 1; -} - -// Compute a unique index out of a position and use it to probe the TB file. To -// encode k pieces of same type and color, first sort the pieces by square in -// ascending order s1 <= s2 <= ... <= sk then compute the unique index as: -// -// idx = Binomial[1][s1] + Binomial[2][s2] + ... + Binomial[k][sk] -// -template -Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* result) { - - Square squares[TBPIECES]; - Piece pieces[TBPIECES]; - uint64_t idx; - int next = 0, size = 0, leadPawnsCnt = 0; - PairsData* d; - Bitboard b, leadPawns = 0; - File tbFile = FILE_A; - - // A given TB entry like KRK has associated two material keys: KRvk and Kvkr. - // If both sides have the same pieces keys are equal. In this case TB tables - // only store the 'white to move' case, so if the position to lookup has black - // to move, we need to switch the color and flip the squares before to lookup. - bool symmetricBlackToMove = (entry->key == entry->key2 && pos.side_to_move()); - - // TB files are calculated for white as stronger side. For instance we have - // KRvK, not KvKR. A position where stronger side is white will have its - // material key == entry->key, otherwise we have to switch the color and - // flip the squares before to lookup. - bool blackStronger = (pos.material_key() != entry->key); - - int flipColor = (symmetricBlackToMove || blackStronger) * 8; - int flipSquares = (symmetricBlackToMove || blackStronger) * 56; - int stm = (symmetricBlackToMove || blackStronger) ^ pos.side_to_move(); - - // For pawns, TB files store 4 separate tables according if leading pawn is on - // file a, b, c or d after reordering. The leading pawn is the one with maximum - // MapPawns[] value, that is the one most toward the edges and with lowest rank. - if (entry->hasPawns) { - - // In all the 4 tables, pawns are at the beginning of the piece sequence and - // their color is the reference one. So we just pick the first one. - Piece pc = Piece(entry->get(0, 0)->pieces[0] ^ flipColor); - - assert(type_of(pc) == PAWN); - - leadPawns = b = pos.pieces(color_of(pc), PAWN); - do - squares[size++] = pop_lsb(b) ^ flipSquares; - while (b); - - leadPawnsCnt = size; - - std::swap(squares[0], *std::max_element(squares, squares + leadPawnsCnt, pawns_comp)); - - tbFile = File(edge_distance(file_of(squares[0]))); - } + auto flags = entry->get(0, f)->flags; - // DTZ tables are one-sided, i.e. they store positions only for white to - // move or only for black to move, so check for side to move to be stm, - // early exit otherwise. - if (!check_dtz_stm(entry, stm, tbFile)) - return *result = CHANGE_STM, Ret(); - - // Now we are ready to get all the position pieces (but the lead pawns) and - // directly map them to the correct color and square. - b = pos.pieces() ^ leadPawns; - do { - Square s = pop_lsb(b); - squares[size] = s ^ flipSquares; - pieces[size++] = Piece(pos.piece_on(s) ^ flipColor); - } while (b); - - assert(size >= 2); - - d = entry->get(stm, tbFile); - - // Then we reorder the pieces to have the same sequence as the one stored - // in pieces[i]: the sequence that ensures the best compression. - for (int i = leadPawnsCnt; i < size - 1; ++i) - for (int j = i + 1; j < size; ++j) - if (d->pieces[i] == pieces[j]) - { - std::swap(pieces[i], pieces[j]); - std::swap(squares[i], squares[j]); - break; + uint8_t* map = entry->map; + uint16_t* idx = entry->get(0, f)->map_idx; + if (flags & TBFlag::Mapped) { + if (flags & TBFlag::Wide) + value = ((uint16_t*) map)[idx[WDLMap[wdl + 2]] + value]; + else + value = map[idx[WDLMap[wdl + 2]] + value]; } - // Now we map again the squares so that the square of the lead piece is in - // the triangle A1-D1-D4. - if (file_of(squares[0]) > FILE_D) - for (int i = 0; i < size; ++i) - squares[i] = flip_file(squares[i]); + // DTZ tables store distance to zero in number of moves or plies. We + // want to return plies, so we have convert to plies when needed. + if ((wdl == WDLWin && !(flags & TBFlag::WinPlies)) || + (wdl == WDLLoss && !(flags & TBFlag::LossPlies)) || wdl == WDLCursedWin || + wdl == WDLBlessedLoss) + value *= 2; - // Encode leading pawns starting with the one with minimum MapPawns[] and - // proceeding in ascending order. - if (entry->hasPawns) { - idx = LeadPawnIdx[leadPawnsCnt][squares[0]]; + return value + 1; + } - std::stable_sort(squares + 1, squares + leadPawnsCnt, pawns_comp); + // Compute a unique index out of a position and use it to probe the TB file. To + // encode k pieces of same type and color, first sort the pieces by square in + // ascending order s1 <= s2 <= ... <= sk then compute the unique index as: + // + // idx = Binomial[1][s1] + Binomial[2][s2] + ... + Binomial[k][sk] + // + template + Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* result) { + + Square squares[TBPIECES]; + Piece pieces[TBPIECES]; + uint64_t idx; + int next = 0, size = 0, leadPawnsCnt = 0; + PairsData* d; + Bitboard b, leadPawns = 0; + File tbFile = FILE_A; + + // A given TB entry like KRK has associated two material keys: KRvk and Kvkr. + // If both sides have the same pieces keys are equal. In this case TB tables + // only store the 'white to move' case, so if the position to lookup has black + // to move, we need to switch the color and flip the squares before to lookup. + bool symmetricBlackToMove = (entry->key == entry->key2 && pos.side_to_move()); + + // TB files are calculated for white as stronger side. For instance we have + // KRvK, not KvKR. A position where stronger side is white will have its + // material key == entry->key, otherwise we have to switch the color and + // flip the squares before to lookup. + bool blackStronger = (pos.material_key() != entry->key); + + int flipColor = (symmetricBlackToMove || blackStronger) * 8; + int flipSquares = (symmetricBlackToMove || blackStronger) * 56; + int stm = (symmetricBlackToMove || blackStronger) ^ pos.side_to_move(); + + // For pawns, TB files store 4 separate tables according if leading pawn is on + // file a, b, c or d after reordering. The leading pawn is the one with maximum + // MapPawns[] value, that is the one most toward the edges and with lowest rank. + if (entry->hasPawns) { + + // In all the 4 tables, pawns are at the beginning of the piece sequence and + // their color is the reference one. So we just pick the first one. + Piece pc = Piece(entry->get(0, 0)->pieces[0] ^ flipColor); + + assert(type_of(pc) == PAWN); + + leadPawns = b = pos.pieces(color_of(pc), PAWN); + do squares[size++] = pop_lsb(b) ^ flipSquares; + while (b); + + leadPawnsCnt = size; + + std::swap(squares[0], + *std::max_element(squares, squares + leadPawnsCnt, pawns_comp)); + + tbFile = File(edge_distance(file_of(squares[0]))); + } - for (int i = 1; i < leadPawnsCnt; ++i) - idx += Binomial[i][MapPawns[squares[i]]]; + // DTZ tables are one-sided, i.e. they store positions only for white to + // move or only for black to move, so check for side to move to be stm, + // early exit otherwise. + if (!check_dtz_stm(entry, stm, tbFile)) return *result = CHANGE_STM, Ret(); + + // Now we are ready to get all the position pieces (but the lead pawns) and + // directly map them to the correct color and square. + b = pos.pieces() ^ leadPawns; + do { + Square s = pop_lsb(b); + squares[size] = s ^ flipSquares; + pieces[size++] = Piece(pos.piece_on(s) ^ flipColor); + } while (b); + + assert(size >= 2); + + d = entry->get(stm, tbFile); + + // Then we reorder the pieces to have the same sequence as the one stored + // in pieces[i]: the sequence that ensures the best compression. + for (int i = leadPawnsCnt; i < size - 1; ++i) + for (int j = i + 1; j < size; ++j) + if (d->pieces[i] == pieces[j]) { + std::swap(pieces[i], pieces[j]); + std::swap(squares[i], squares[j]); + break; + } + + // Now we map again the squares so that the square of the lead piece is in + // the triangle A1-D1-D4. + if (file_of(squares[0]) > FILE_D) + for (int i = 0; i < size; ++i) squares[i] = flip_file(squares[i]); + + // Encode leading pawns starting with the one with minimum MapPawns[] and + // proceeding in ascending order. + if (entry->hasPawns) { + idx = LeadPawnIdx[leadPawnsCnt][squares[0]]; + + std::stable_sort(squares + 1, squares + leadPawnsCnt, pawns_comp); + + for (int i = 1; i < leadPawnsCnt; ++i) idx += Binomial[i][MapPawns[squares[i]]]; + + goto encode_remaining; // With pawns we have finished special treatments + } - goto encode_remaining; // With pawns we have finished special treatments - } + // In positions without pawns, we further flip the squares to ensure leading + // piece is below RANK_5. + if (rank_of(squares[0]) > RANK_4) + for (int i = 0; i < size; ++i) squares[i] = flip_rank(squares[i]); - // In positions without pawns, we further flip the squares to ensure leading - // piece is below RANK_5. - if (rank_of(squares[0]) > RANK_4) - for (int i = 0; i < size; ++i) - squares[i] = flip_rank(squares[i]); - - // Look for the first piece of the leading group not on the A1-D4 diagonal - // and ensure it is mapped below the diagonal. - for (int i = 0; i < d->groupLen[0]; ++i) { - if (!off_A1H8(squares[i])) - continue; - - if (off_A1H8(squares[i]) > 0) // A1-H8 diagonal flip: SQ_A3 -> SQ_C1 - for (int j = i; j < size; ++j) - squares[j] = Square(((squares[j] >> 3) | (squares[j] << 3)) & 63); - break; - } + // Look for the first piece of the leading group not on the A1-D4 diagonal + // and ensure it is mapped below the diagonal. + for (int i = 0; i < d->groupLen[0]; ++i) { + if (!off_A1H8(squares[i])) continue; - // Encode the leading group. - // - // Suppose we have KRvK. Let's say the pieces are on square numbers wK, wR - // and bK (each 0...63). The simplest way to map this position to an index - // is like this: - // - // index = wK * 64 * 64 + wR * 64 + bK; - // - // But this way the TB is going to have 64*64*64 = 262144 positions, with - // lots of positions being equivalent (because they are mirrors of each - // other) and lots of positions being invalid (two pieces on one square, - // adjacent kings, etc.). - // Usually the first step is to take the wK and bK together. There are just - // 462 ways legal and not-mirrored ways to place the wK and bK on the board. - // Once we have placed the wK and bK, there are 62 squares left for the wR - // Mapping its square from 0..63 to available squares 0..61 can be done like: - // - // wR -= (wR > wK) + (wR > bK); - // - // In words: if wR "comes later" than wK, we deduct 1, and the same if wR - // "comes later" than bK. In case of two same pieces like KRRvK we want to - // place the two Rs "together". If we have 62 squares left, we can place two - // Rs "together" in 62 * 61 / 2 ways (we divide by 2 because rooks can be - // swapped and still get the same position.) - // - // In case we have at least 3 unique pieces (included kings) we encode them - // together. - if (entry->hasUniquePieces) { - - int adjust1 = squares[1] > squares[0]; - int adjust2 = (squares[2] > squares[0]) + (squares[2] > squares[1]); - - // First piece is below a1-h8 diagonal. MapA1D1D4[] maps the b1-d1-d3 - // triangle to 0...5. There are 63 squares for second piece and and 62 - // (mapped to 0...61) for the third. - if (off_A1H8(squares[0])) - idx = ( MapA1D1D4[squares[0]] * 63 - + (squares[1] - adjust1)) * 62 - + squares[2] - adjust2; - - // First piece is on a1-h8 diagonal, second below: map this occurrence to - // 6 to differentiate from the above case, rank_of() maps a1-d4 diagonal - // to 0...3 and finally MapB1H1H7[] maps the b1-h1-h7 triangle to 0..27. - else if (off_A1H8(squares[1])) - idx = ( 6 * 63 + rank_of(squares[0]) * 28 - + MapB1H1H7[squares[1]]) * 62 - + squares[2] - adjust2; - - // First two pieces are on a1-h8 diagonal, third below - else if (off_A1H8(squares[2])) - idx = 6 * 63 * 62 + 4 * 28 * 62 - + rank_of(squares[0]) * 7 * 28 - + (rank_of(squares[1]) - adjust1) * 28 - + MapB1H1H7[squares[2]]; - - // All 3 pieces on the diagonal a1-h8 - else - idx = 6 * 63 * 62 + 4 * 28 * 62 + 4 * 7 * 28 - + rank_of(squares[0]) * 7 * 6 - + (rank_of(squares[1]) - adjust1) * 6 - + (rank_of(squares[2]) - adjust2); - } else - // We don't have at least 3 unique pieces, like in KRRvKBB, just map - // the kings. - idx = MapKK[MapA1D1D4[squares[0]]][squares[1]]; + if (off_A1H8(squares[i]) > 0) // A1-H8 diagonal flip: SQ_A3 -> SQ_C1 + for (int j = i; j < size; ++j) + squares[j] = Square(((squares[j] >> 3) | (squares[j] << 3)) & 63); + break; + } + + // Encode the leading group. + // + // Suppose we have KRvK. Let's say the pieces are on square numbers wK, wR + // and bK (each 0...63). The simplest way to map this position to an index + // is like this: + // + // index = wK * 64 * 64 + wR * 64 + bK; + // + // But this way the TB is going to have 64*64*64 = 262144 positions, with + // lots of positions being equivalent (because they are mirrors of each + // other) and lots of positions being invalid (two pieces on one square, + // adjacent kings, etc.). + // Usually the first step is to take the wK and bK together. There are just + // 462 ways legal and not-mirrored ways to place the wK and bK on the board. + // Once we have placed the wK and bK, there are 62 squares left for the wR + // Mapping its square from 0..63 to available squares 0..61 can be done like: + // + // wR -= (wR > wK) + (wR > bK); + // + // In words: if wR "comes later" than wK, we deduct 1, and the same if wR + // "comes later" than bK. In case of two same pieces like KRRvK we want to + // place the two Rs "together". If we have 62 squares left, we can place two + // Rs "together" in 62 * 61 / 2 ways (we divide by 2 because rooks can be + // swapped and still get the same position.) + // + // In case we have at least 3 unique pieces (included kings) we encode them + // together. + if (entry->hasUniquePieces) { + + int adjust1 = squares[1] > squares[0]; + int adjust2 = (squares[2] > squares[0]) + (squares[2] > squares[1]); + + // First piece is below a1-h8 diagonal. MapA1D1D4[] maps the b1-d1-d3 + // triangle to 0...5. There are 63 squares for second piece and and 62 + // (mapped to 0...61) for the third. + if (off_A1H8(squares[0])) + idx = (MapA1D1D4[squares[0]] * 63 + (squares[1] - adjust1)) * 62 + squares[2] - + adjust2; + + // First piece is on a1-h8 diagonal, second below: map this occurrence to + // 6 to differentiate from the above case, rank_of() maps a1-d4 diagonal + // to 0...3 and finally MapB1H1H7[] maps the b1-h1-h7 triangle to 0..27. + else if (off_A1H8(squares[1])) + idx = (6 * 63 + rank_of(squares[0]) * 28 + MapB1H1H7[squares[1]]) * 62 + + squares[2] - adjust2; + + // First two pieces are on a1-h8 diagonal, third below + else if (off_A1H8(squares[2])) + idx = 6 * 63 * 62 + 4 * 28 * 62 + rank_of(squares[0]) * 7 * 28 + + (rank_of(squares[1]) - adjust1) * 28 + MapB1H1H7[squares[2]]; + + // All 3 pieces on the diagonal a1-h8 + else + idx = 6 * 63 * 62 + 4 * 28 * 62 + 4 * 7 * 28 + rank_of(squares[0]) * 7 * 6 + + (rank_of(squares[1]) - adjust1) * 6 + (rank_of(squares[2]) - adjust2); + } else + // We don't have at least 3 unique pieces, like in KRRvKBB, just map + // the kings. + idx = MapKK[MapA1D1D4[squares[0]]][squares[1]]; encode_remaining: - idx *= d->groupIdx[0]; - Square* groupSq = squares + d->groupLen[0]; - - // Encode remaining pawns then pieces according to square, in ascending order - bool remainingPawns = entry->hasPawns && entry->pawnCount[1]; - - while (d->groupLen[++next]) - { - std::stable_sort(groupSq, groupSq + d->groupLen[next]); - uint64_t n = 0; - - // Map down a square if "comes later" than a square in the previous - // groups (similar to what done earlier for leading group pieces). - for (int i = 0; i < d->groupLen[next]; ++i) - { - auto f = [&](Square s) { return groupSq[i] > s; }; - auto adjust = std::count_if(squares, groupSq, f); - n += Binomial[i + 1][groupSq[i] - adjust - 8 * remainingPawns]; - } + idx *= d->groupIdx[0]; + Square* groupSq = squares + d->groupLen[0]; + + // Encode remaining pawns then pieces according to square, in ascending order + bool remainingPawns = entry->hasPawns && entry->pawnCount[1]; + + while (d->groupLen[++next]) { + std::stable_sort(groupSq, groupSq + d->groupLen[next]); + uint64_t n = 0; + + // Map down a square if "comes later" than a square in the previous + // groups (similar to what done earlier for leading group pieces). + for (int i = 0; i < d->groupLen[next]; ++i) { + auto f = [&](Square s) { return groupSq[i] > s; }; + auto adjust = std::count_if(squares, groupSq, f); + n += Binomial[i + 1][groupSq[i] - adjust - 8 * remainingPawns]; + } - remainingPawns = false; - idx += n * d->groupIdx[next]; - groupSq += d->groupLen[next]; - } + remainingPawns = false; + idx += n * d->groupIdx[next]; + groupSq += d->groupLen[next]; + } - // Now that we have the index, decompress the pair and get the score - return map_score(entry, tbFile, decompress_pairs(d, idx), wdl); -} - -// Group together pieces that will be encoded together. The general rule is that -// a group contains pieces of same type and color. The exception is the leading -// group that, in case of positions without pawns, can be formed by 3 different -// pieces (default) or by the king pair when there is not a unique piece apart -// from the kings. When there are pawns, pawns are always first in pieces[]. -// -// As example KRKN -> KRK + N, KNNK -> KK + NN, KPPKP -> P + PP + K + K -// -// The actual grouping depends on the TB generator and can be inferred from the -// sequence of pieces in piece[] array. -template -void set_groups(T& e, PairsData* d, int order[], File f) { - - int n = 0, firstLen = e.hasPawns ? 0 : e.hasUniquePieces ? 3 : 2; - d->groupLen[n] = 1; - - // Number of pieces per group is stored in groupLen[], for instance in KRKN - // the encoder will default on '111', so groupLen[] will be (3, 1). - for (int i = 1; i < e.pieceCount; ++i) - if (--firstLen > 0 || d->pieces[i] == d->pieces[i - 1]) - d->groupLen[n]++; - else - d->groupLen[++n] = 1; - - d->groupLen[++n] = 0; // Zero-terminated - - // The sequence in pieces[] defines the groups, but not the order in which - // they are encoded. If the pieces in a group g can be combined on the board - // in N(g) different ways, then the position encoding will be of the form: - // - // g1 * N(g2) * N(g3) + g2 * N(g3) + g3 - // - // This ensures unique encoding for the whole position. The order of the - // groups is a per-table parameter and could not follow the canonical leading - // pawns/pieces -> remaining pawns -> remaining pieces. In particular the - // first group is at order[0] position and the remaining pawns, when present, - // are at order[1] position. - bool pp = e.hasPawns && e.pawnCount[1]; // Pawns on both sides - int next = pp ? 2 : 1; - int freeSquares = 64 - d->groupLen[0] - (pp ? d->groupLen[1] : 0); - uint64_t idx = 1; - - for (int k = 0; next < n || k == order[0] || k == order[1]; ++k) - if (k == order[0]) // Leading pawns or pieces - { - d->groupIdx[0] = idx; - idx *= e.hasPawns ? LeadPawnsSize[d->groupLen[0]][f] - : e.hasUniquePieces ? 31332 : 462; - } - else if (k == order[1]) // Remaining pawns - { - d->groupIdx[1] = idx; - idx *= Binomial[d->groupLen[1]][48 - d->groupLen[0]]; - } - else // Remaining pieces - { - d->groupIdx[next] = idx; - idx *= Binomial[d->groupLen[next]][freeSquares]; - freeSquares -= d->groupLen[next++]; + // Now that we have the index, decompress the pair and get the score + return map_score(entry, tbFile, decompress_pairs(d, idx), wdl); } - d->groupIdx[n] = idx; -} + // Group together pieces that will be encoded together. The general rule is that + // a group contains pieces of same type and color. The exception is the leading + // group that, in case of positions without pawns, can be formed by 3 different + // pieces (default) or by the king pair when there is not a unique piece apart + // from the kings. When there are pawns, pawns are always first in pieces[]. + // + // As example KRKN -> KRK + N, KNNK -> KK + NN, KPPKP -> P + PP + K + K + // + // The actual grouping depends on the TB generator and can be inferred from the + // sequence of pieces in piece[] array. + template void set_groups(T& e, PairsData* d, int order[], File f) { + + int n = 0, firstLen = e.hasPawns ? 0 : e.hasUniquePieces ? 3 : 2; + d->groupLen[n] = 1; + + // Number of pieces per group is stored in groupLen[], for instance in KRKN + // the encoder will default on '111', so groupLen[] will be (3, 1). + for (int i = 1; i < e.pieceCount; ++i) + if (--firstLen > 0 || d->pieces[i] == d->pieces[i - 1]) + d->groupLen[n]++; + else + d->groupLen[++n] = 1; + + d->groupLen[++n] = 0; // Zero-terminated + + // The sequence in pieces[] defines the groups, but not the order in which + // they are encoded. If the pieces in a group g can be combined on the board + // in N(g) different ways, then the position encoding will be of the form: + // + // g1 * N(g2) * N(g3) + g2 * N(g3) + g3 + // + // This ensures unique encoding for the whole position. The order of the + // groups is a per-table parameter and could not follow the canonical leading + // pawns/pieces -> remaining pawns -> remaining pieces. In particular the + // first group is at order[0] position and the remaining pawns, when present, + // are at order[1] position. + bool pp = e.hasPawns && e.pawnCount[1]; // Pawns on both sides + int next = pp ? 2 : 1; + int freeSquares = 64 - d->groupLen[0] - (pp ? d->groupLen[1] : 0); + uint64_t idx = 1; + + for (int k = 0; next < n || k == order[0] || k == order[1]; ++k) + if (k == order[0]) // Leading pawns or pieces + { + d->groupIdx[0] = idx; + idx *= e.hasPawns ? LeadPawnsSize[d->groupLen[0]][f] : + e.hasUniquePieces ? 31332 : + 462; + } else if (k == order[1]) // Remaining pawns + { + d->groupIdx[1] = idx; + idx *= Binomial[d->groupLen[1]][48 - d->groupLen[0]]; + } else // Remaining pieces + { + d->groupIdx[next] = idx; + idx *= Binomial[d->groupLen[next]][freeSquares]; + freeSquares -= d->groupLen[next++]; + } -// In Recursive Pairing each symbol represents a pair of children symbols. So -// read d->btree[] symbols data and expand each one in his left and right child -// symbol until reaching the leafs that represent the symbol value. -uint8_t set_symlen(PairsData* d, Sym s, std::vector& visited) { + d->groupIdx[n] = idx; + } - visited[s] = true; // We can set it now because tree is acyclic - Sym sr = d->btree[s].get(); + // In Recursive Pairing each symbol represents a pair of children symbols. So + // read d->btree[] symbols data and expand each one in his left and right child + // symbol until reaching the leafs that represent the symbol value. + uint8_t set_symlen(PairsData* d, Sym s, std::vector& visited) { - if (sr == 0xFFF) - return 0; + visited[s] = true; // We can set it now because tree is acyclic + Sym sr = d->btree[s].get(); - Sym sl = d->btree[s].get(); + if (sr == 0xFFF) return 0; - if (!visited[sl]) - d->symlen[sl] = set_symlen(d, sl, visited); + Sym sl = d->btree[s].get(); - if (!visited[sr]) - d->symlen[sr] = set_symlen(d, sr, visited); + if (!visited[sl]) d->symlen[sl] = set_symlen(d, sl, visited); - return d->symlen[sl] + d->symlen[sr] + 1; -} + if (!visited[sr]) d->symlen[sr] = set_symlen(d, sr, visited); -uint8_t* set_sizes(PairsData* d, uint8_t* data) { + return d->symlen[sl] + d->symlen[sr] + 1; + } - d->flags = *data++; + uint8_t* set_sizes(PairsData* d, uint8_t* data) { - if (d->flags & TBFlag::SingleValue) { - d->blocksNum = d->blockLengthSize = 0; - d->span = d->sparseIndexSize = 0; // Broken MSVC zero-init - d->minSymLen = *data++; // Here we store the single value - return data; - } + d->flags = *data++; - // groupLen[] is a zero-terminated list of group lengths, the last groupIdx[] - // element stores the biggest index that is the tb size. - uint64_t tbSize = d->groupIdx[std::find(d->groupLen, d->groupLen + 7, 0) - d->groupLen]; - - d->sizeofBlock = 1ULL << *data++; - d->span = 1ULL << *data++; - d->sparseIndexSize = size_t((tbSize + d->span - 1) / d->span); // Round up - auto padding = number(data++); - d->blocksNum = number(data); data += sizeof(uint32_t); - d->blockLengthSize = d->blocksNum + padding; // Padded to ensure SparseIndex[] - // does not point out of range. - d->maxSymLen = *data++; - d->minSymLen = *data++; - d->lowestSym = (Sym*)data; - d->base64.resize(d->maxSymLen - d->minSymLen + 1); - - // See https://en.wikipedia.org/wiki/Huffman_coding - // The canonical code is ordered such that longer symbols (in terms of - // the number of bits of their Huffman code) have lower numeric value, - // so that d->lowestSym[i] >= d->lowestSym[i+1] (when read as LittleEndian). - // Starting from this we compute a base64[] table indexed by symbol length - // and containing 64 bit values so that d->base64[i] >= d->base64[i+1]. - - // Implementation note: we first cast the unsigned size_t "base64.size()" - // to a signed int "base64_size" variable and then we are able to subtract 2, - // avoiding unsigned overflow warnings. - - int base64_size = static_cast(d->base64.size()); - for (int i = base64_size - 2; i >= 0; --i) { - d->base64[i] = (d->base64[i + 1] + number(&d->lowestSym[i]) - - number(&d->lowestSym[i + 1])) / 2; - - assert(d->base64[i] * 2 >= d->base64[i+1]); - } + if (d->flags & TBFlag::SingleValue) { + d->blocksNum = d->blockLengthSize = 0; + d->span = d->sparseIndexSize = 0; // Broken MSVC zero-init + d->minSymLen = *data++; // Here we store the single value + return data; + } - // Now left-shift by an amount so that d->base64[i] gets shifted 1 bit more - // than d->base64[i+1] and given the above assert condition, we ensure that - // d->base64[i] >= d->base64[i+1]. Moreover for any symbol s64 of length i - // and right-padded to 64 bits holds d->base64[i-1] >= s64 >= d->base64[i]. - for (int i = 0; i < base64_size; ++i) - d->base64[i] <<= 64 - i - d->minSymLen; // Right-padding to 64 bits - - data += base64_size * sizeof(Sym); - d->symlen.resize(number(data)); data += sizeof(uint16_t); - d->btree = (LR*)data; - - // The compression scheme used is "Recursive Pairing", that replaces the most - // frequent adjacent pair of symbols in the source message by a new symbol, - // reevaluating the frequencies of all of the symbol pairs with respect to - // the extended alphabet, and then repeating the process. - // See https://web.archive.org/web/20201106232444/http://www.larsson.dogma.net/dcc99.pdf - std::vector visited(d->symlen.size()); - - for (Sym sym = 0; sym < d->symlen.size(); ++sym) - if (!visited[sym]) - d->symlen[sym] = set_symlen(d, sym, visited); - - return data + d->symlen.size() * sizeof(LR) + (d->symlen.size() & 1); -} - -uint8_t* set_dtz_map(TBTable&, uint8_t* data, File) { return data; } - -uint8_t* set_dtz_map(TBTable& e, uint8_t* data, File maxFile) { - - e.map = data; - - for (File f = FILE_A; f <= maxFile; ++f) { - auto flags = e.get(0, f)->flags; - if (flags & TBFlag::Mapped) { - if (flags & TBFlag::Wide) { - data += (uintptr_t)data & 1; // Word alignment, we may have a mixed table - for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x - e.get(0, f)->map_idx[i] = (uint16_t)((uint16_t *)data - (uint16_t *)e.map + 1); - data += 2 * number(data) + 2; - } + // groupLen[] is a zero-terminated list of group lengths, the last groupIdx[] + // element stores the biggest index that is the tb size. + uint64_t tbSize = d->groupIdx[std::find(d->groupLen, d->groupLen + 7, 0) - d->groupLen]; + + d->sizeofBlock = 1ULL << *data++; + d->span = 1ULL << *data++; + d->sparseIndexSize = size_t((tbSize + d->span - 1) / d->span); // Round up + auto padding = number(data++); + d->blocksNum = number(data); + data += sizeof(uint32_t); + d->blockLengthSize = d->blocksNum + padding; // Padded to ensure SparseIndex[] + // does not point out of range. + d->maxSymLen = *data++; + d->minSymLen = *data++; + d->lowestSym = (Sym*) data; + d->base64.resize(d->maxSymLen - d->minSymLen + 1); + + // See https://en.wikipedia.org/wiki/Huffman_coding + // The canonical code is ordered such that longer symbols (in terms of + // the number of bits of their Huffman code) have lower numeric value, + // so that d->lowestSym[i] >= d->lowestSym[i+1] (when read as LittleEndian). + // Starting from this we compute a base64[] table indexed by symbol length + // and containing 64 bit values so that d->base64[i] >= d->base64[i+1]. + + // Implementation note: we first cast the unsigned size_t "base64.size()" + // to a signed int "base64_size" variable and then we are able to subtract 2, + // avoiding unsigned overflow warnings. + + int base64_size = static_cast(d->base64.size()); + for (int i = base64_size - 2; i >= 0; --i) { + d->base64[i] = (d->base64[i + 1] + number(&d->lowestSym[i]) - + number(&d->lowestSym[i + 1])) / + 2; + + assert(d->base64[i] * 2 >= d->base64[i + 1]); } - else { - for (int i = 0; i < 4; ++i) { - e.get(0, f)->map_idx[i] = (uint16_t)(data - e.map + 1); - data += *data + 1; + + // Now left-shift by an amount so that d->base64[i] gets shifted 1 bit more + // than d->base64[i+1] and given the above assert condition, we ensure that + // d->base64[i] >= d->base64[i+1]. Moreover for any symbol s64 of length i + // and right-padded to 64 bits holds d->base64[i-1] >= s64 >= d->base64[i]. + for (int i = 0; i < base64_size; ++i) + d->base64[i] <<= 64 - i - d->minSymLen; // Right-padding to 64 bits + + data += base64_size * sizeof(Sym); + d->symlen.resize(number(data)); + data += sizeof(uint16_t); + d->btree = (LR*) data; + + // The compression scheme used is "Recursive Pairing", that replaces the most + // frequent adjacent pair of symbols in the source message by a new symbol, + // reevaluating the frequencies of all of the symbol pairs with respect to + // the extended alphabet, and then repeating the process. + // See https://web.archive.org/web/20201106232444/http://www.larsson.dogma.net/dcc99.pdf + std::vector visited(d->symlen.size()); + + for (Sym sym = 0; sym < d->symlen.size(); ++sym) + if (!visited[sym]) d->symlen[sym] = set_symlen(d, sym, visited); + + return data + d->symlen.size() * sizeof(LR) + (d->symlen.size() & 1); + } + + uint8_t* set_dtz_map(TBTable&, uint8_t* data, File) { return data; } + + uint8_t* set_dtz_map(TBTable& e, uint8_t* data, File maxFile) { + + e.map = data; + + for (File f = FILE_A; f <= maxFile; ++f) { + auto flags = e.get(0, f)->flags; + if (flags & TBFlag::Mapped) { + if (flags & TBFlag::Wide) { + data += (uintptr_t) data & 1; // Word alignment, we may have a mixed table + for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x + e.get(0, f)->map_idx[i] = + (uint16_t) ((uint16_t*) data - (uint16_t*) e.map + 1); + data += 2 * number(data) + 2; + } + } else { + for (int i = 0; i < 4; ++i) { + e.get(0, f)->map_idx[i] = (uint16_t) (data - e.map + 1); + data += *data + 1; + } + } } } + + return data += (uintptr_t) data & 1; // Word alignment } - } - return data += (uintptr_t)data & 1; // Word alignment -} + // Populate entry's PairsData records with data from the just memory mapped file. + // Called at first access. + template void set(T& e, uint8_t* data) { -// Populate entry's PairsData records with data from the just memory mapped file. -// Called at first access. -template -void set(T& e, uint8_t* data) { + PairsData* d; - PairsData* d; + enum { + Split = 1, + HasPawns = 2 + }; - enum { Split = 1, HasPawns = 2 }; + assert(e.hasPawns == bool(*data & HasPawns)); + assert((e.key != e.key2) == bool(*data & Split)); - assert(e.hasPawns == bool(*data & HasPawns)); - assert((e.key != e.key2) == bool(*data & Split)); + data++; // First byte stores flags - data++; // First byte stores flags + const int sides = T::Sides == 2 && (e.key != e.key2) ? 2 : 1; + const File maxFile = e.hasPawns ? FILE_D : FILE_A; - const int sides = T::Sides == 2 && (e.key != e.key2) ? 2 : 1; - const File maxFile = e.hasPawns ? FILE_D : FILE_A; + bool pp = e.hasPawns && e.pawnCount[1]; // Pawns on both sides - bool pp = e.hasPawns && e.pawnCount[1]; // Pawns on both sides + assert(!pp || e.pawnCount[0]); - assert(!pp || e.pawnCount[0]); + for (File f = FILE_A; f <= maxFile; ++f) { - for (File f = FILE_A; f <= maxFile; ++f) { + for (int i = 0; i < sides; i++) *e.get(i, f) = PairsData(); - for (int i = 0; i < sides; i++) - *e.get(i, f) = PairsData(); + int order[][2] = {{*data & 0xF, pp ? *(data + 1) & 0xF : 0xF}, + {*data >> 4, pp ? *(data + 1) >> 4 : 0xF}}; + data += 1 + pp; - int order[][2] = { { *data & 0xF, pp ? *(data + 1) & 0xF : 0xF }, - { *data >> 4, pp ? *(data + 1) >> 4 : 0xF } }; - data += 1 + pp; + for (int k = 0; k < e.pieceCount; ++k, ++data) + for (int i = 0; i < sides; i++) + e.get(i, f)->pieces[k] = Piece(i ? *data >> 4 : *data & 0xF); - for (int k = 0; k < e.pieceCount; ++k, ++data) - for (int i = 0; i < sides; i++) - e.get(i, f)->pieces[k] = Piece(i ? *data >> 4 : *data & 0xF); + for (int i = 0; i < sides; ++i) set_groups(e, e.get(i, f), order[i], f); + } - for (int i = 0; i < sides; ++i) - set_groups(e, e.get(i, f), order[i], f); - } + data += (uintptr_t) data & 1; // Word alignment - data += (uintptr_t)data & 1; // Word alignment + for (File f = FILE_A; f <= maxFile; ++f) + for (int i = 0; i < sides; i++) data = set_sizes(e.get(i, f), data); - for (File f = FILE_A; f <= maxFile; ++f) - for (int i = 0; i < sides; i++) - data = set_sizes(e.get(i, f), data); + data = set_dtz_map(e, data, maxFile); - data = set_dtz_map(e, data, maxFile); + for (File f = FILE_A; f <= maxFile; ++f) + for (int i = 0; i < sides; i++) { + (d = e.get(i, f))->sparseIndex = (SparseEntry*) data; + data += d->sparseIndexSize * sizeof(SparseEntry); + } - for (File f = FILE_A; f <= maxFile; ++f) - for (int i = 0; i < sides; i++) { - (d = e.get(i, f))->sparseIndex = (SparseEntry*)data; - data += d->sparseIndexSize * sizeof(SparseEntry); - } + for (File f = FILE_A; f <= maxFile; ++f) + for (int i = 0; i < sides; i++) { + (d = e.get(i, f))->blockLength = (uint16_t*) data; + data += d->blockLengthSize * sizeof(uint16_t); + } - for (File f = FILE_A; f <= maxFile; ++f) - for (int i = 0; i < sides; i++) { - (d = e.get(i, f))->blockLength = (uint16_t*)data; - data += d->blockLengthSize * sizeof(uint16_t); + for (File f = FILE_A; f <= maxFile; ++f) + for (int i = 0; i < sides; i++) { + data = (uint8_t*) (((uintptr_t) data + 0x3F) & ~0x3F); // 64 byte alignment + (d = e.get(i, f))->data = data; + data += d->blocksNum * d->sizeofBlock; + } } - for (File f = FILE_A; f <= maxFile; ++f) - for (int i = 0; i < sides; i++) { - data = (uint8_t*)(((uintptr_t)data + 0x3F) & ~0x3F); // 64 byte alignment - (d = e.get(i, f))->data = data; - data += d->blocksNum * d->sizeofBlock; - } -} + // If the TB file corresponding to the given position is already memory mapped + // then return its base address, otherwise try to memory map and init it. Called + // at every probe, memory map and init only at first access. Function is thread + // safe and can be called concurrently. + template void* mapped(TBTable& e, const Position& pos) { -// If the TB file corresponding to the given position is already memory mapped -// then return its base address, otherwise try to memory map and init it. Called -// at every probe, memory map and init only at first access. Function is thread -// safe and can be called concurrently. -template -void* mapped(TBTable& e, const Position& pos) { + static std::mutex mutex; - static std::mutex mutex; + // Use 'acquire' to avoid a thread reading 'ready' == true while + // another is still working. (compiler reordering may cause this). + if (e.ready.load(std::memory_order_acquire)) + return e.baseAddress; // Could be nullptr if file does not exist - // Use 'acquire' to avoid a thread reading 'ready' == true while - // another is still working. (compiler reordering may cause this). - if (e.ready.load(std::memory_order_acquire)) - return e.baseAddress; // Could be nullptr if file does not exist + std::scoped_lock lk(mutex); - std::scoped_lock lk(mutex); + if (e.ready.load(std::memory_order_relaxed)) // Recheck under lock + return e.baseAddress; - if (e.ready.load(std::memory_order_relaxed)) // Recheck under lock - return e.baseAddress; + // Pieces strings in decreasing order for each color, like ("KPP","KR") + std::string fname, w, b; + for (PieceType pt = KING; pt >= PAWN; --pt) { + w += std::string(popcount(pos.pieces(WHITE, pt)), PieceToChar[pt]); + b += std::string(popcount(pos.pieces(BLACK, pt)), PieceToChar[pt]); + } - // Pieces strings in decreasing order for each color, like ("KPP","KR") - std::string fname, w, b; - for (PieceType pt = KING; pt >= PAWN; --pt) { - w += std::string(popcount(pos.pieces(WHITE, pt)), PieceToChar[pt]); - b += std::string(popcount(pos.pieces(BLACK, pt)), PieceToChar[pt]); - } + fname = (e.key == pos.material_key() ? w + 'v' + b : b + 'v' + w) + + (Type == WDL ? ".rtbw" : ".rtbz"); - fname = (e.key == pos.material_key() ? w + 'v' + b : b + 'v' + w) - + (Type == WDL ? ".rtbw" : ".rtbz"); + uint8_t* data = TBFile(fname).map(&e.baseAddress, &e.mapping, Type); - uint8_t* data = TBFile(fname).map(&e.baseAddress, &e.mapping, Type); + if (data) set(e, data); - if (data) - set(e, data); + e.ready.store(true, std::memory_order_release); + return e.baseAddress; + } - e.ready.store(true, std::memory_order_release); - return e.baseAddress; -} + template::Ret> + Ret probe_table(const Position& pos, ProbeState* result, WDLScore wdl = WDLDraw) { -template::Ret> -Ret probe_table(const Position& pos, ProbeState* result, WDLScore wdl = WDLDraw) { + if (pos.count() == 2) // KvK + return Ret(WDLDraw); - if (pos.count() == 2) // KvK - return Ret(WDLDraw); + TBTable* entry = TBTables.get(pos.material_key()); - TBTable* entry = TBTables.get(pos.material_key()); + if (!entry || !mapped(*entry, pos)) return *result = FAIL, Ret(); - if (!entry || !mapped(*entry, pos)) - return *result = FAIL, Ret(); + return do_probe_table(pos, entry, wdl, result); + } - return do_probe_table(pos, entry, wdl, result); -} + // For a position where the side to move has a winning capture it is not necessary + // to store a winning value so the generator treats such positions as "don't cares" + // and tries to assign to it a value that improves the compression ratio. Similarly, + // if the side to move has a drawing capture, then the position is at least drawn. + // If the position is won, then the TB needs to store a win value. But if the + // position is drawn, the TB may store a loss value if that is better for compression. + // All of this means that during probing, the engine must look at captures and probe + // their results and must probe the position itself. The "best" result of these + // probes is the correct result for the position. + // DTZ tables do not store values when a following move is a zeroing winning move + // (winning capture or winning pawn move). Also DTZ store wrong values for positions + // where the best move is an ep-move (even if losing). So in all these cases set + // the state to ZEROING_BEST_MOVE. + template WDLScore search(Position& pos, ProbeState* result) { + + WDLScore value, bestValue = WDLLoss; + StateInfo st; + + auto moveList = MoveList(pos); + size_t totalCount = moveList.size(), moveCount = 0; + + for (const Move move : moveList) { + if (!pos.capture(move) && + (!CheckZeroingMoves || type_of(pos.moved_piece(move)) != PAWN)) + continue; + + moveCount++; + + pos.do_move(move, st); + value = -search(pos, result); + pos.undo_move(move); + + if (*result == FAIL) return WDLDraw; + + if (value > bestValue) { + bestValue = value; + + if (value >= WDLWin) { + *result = ZEROING_BEST_MOVE; // Winning DTZ-zeroing move + return value; + } + } + } -// For a position where the side to move has a winning capture it is not necessary -// to store a winning value so the generator treats such positions as "don't cares" -// and tries to assign to it a value that improves the compression ratio. Similarly, -// if the side to move has a drawing capture, then the position is at least drawn. -// If the position is won, then the TB needs to store a win value. But if the -// position is drawn, the TB may store a loss value if that is better for compression. -// All of this means that during probing, the engine must look at captures and probe -// their results and must probe the position itself. The "best" result of these -// probes is the correct result for the position. -// DTZ tables do not store values when a following move is a zeroing winning move -// (winning capture or winning pawn move). Also DTZ store wrong values for positions -// where the best move is an ep-move (even if losing). So in all these cases set -// the state to ZEROING_BEST_MOVE. -template -WDLScore search(Position& pos, ProbeState* result) { + // In case we have already searched all the legal moves we don't have to probe + // the TB because the stored score could be wrong. For instance TB tables + // do not contain information on position with ep rights, so in this case + // the result of probe_wdl_table is wrong. Also in case of only capture + // moves, for instance here 4K3/4q3/6p1/2k5/6p1/8/8/8 w - - 0 7, we have to + // return with ZEROING_BEST_MOVE set. + bool noMoreMoves = (moveCount && moveCount == totalCount); - WDLScore value, bestValue = WDLLoss; - StateInfo st; + if (noMoreMoves) + value = bestValue; + else { + value = probe_table(pos, result); - auto moveList = MoveList(pos); - size_t totalCount = moveList.size(), moveCount = 0; + if (*result == FAIL) return WDLDraw; + } - for (const Move move : moveList) - { - if ( !pos.capture(move) - && (!CheckZeroingMoves || type_of(pos.moved_piece(move)) != PAWN)) - continue; + // DTZ stores a "don't care" value if bestValue is a win + if (bestValue >= value) + return *result = (bestValue > WDLDraw || noMoreMoves ? ZEROING_BEST_MOVE : OK), + bestValue; - moveCount++; + return *result = OK, value; + } - pos.do_move(move, st); - value = -search(pos, result); - pos.undo_move(move); + } // namespace - if (*result == FAIL) - return WDLDraw; - if (value > bestValue) - { - bestValue = value; + /// Tablebases::init() is called at startup and after every change to + /// "SyzygyPath" UCI option to (re)create the various tables. It is not thread + /// safe, nor it needs to be. + void Tablebases::init(const std::string& paths) { - if (value >= WDLWin) - { - *result = ZEROING_BEST_MOVE; // Winning DTZ-zeroing move - return value; - } - } - } + TBTables.clear(); + MaxCardinality = 0; + TBFile::Paths = paths; - // In case we have already searched all the legal moves we don't have to probe - // the TB because the stored score could be wrong. For instance TB tables - // do not contain information on position with ep rights, so in this case - // the result of probe_wdl_table is wrong. Also in case of only capture - // moves, for instance here 4K3/4q3/6p1/2k5/6p1/8/8/8 w - - 0 7, we have to - // return with ZEROING_BEST_MOVE set. - bool noMoreMoves = (moveCount && moveCount == totalCount); - - if (noMoreMoves) - value = bestValue; - else - { - value = probe_table(pos, result); - - if (*result == FAIL) - return WDLDraw; - } + if (paths.empty() || paths == "") return; - // DTZ stores a "don't care" value if bestValue is a win - if (bestValue >= value) - return *result = ( bestValue > WDLDraw - || noMoreMoves ? ZEROING_BEST_MOVE : OK), bestValue; - - return *result = OK, value; -} - -} // namespace - - -/// Tablebases::init() is called at startup and after every change to -/// "SyzygyPath" UCI option to (re)create the various tables. It is not thread -/// safe, nor it needs to be. -void Tablebases::init(const std::string& paths) { - - TBTables.clear(); - MaxCardinality = 0; - TBFile::Paths = paths; - - if (paths.empty() || paths == "") - return; - - // MapB1H1H7[] encodes a square below a1-h8 diagonal to 0..27 - int code = 0; - for (Square s = SQ_A1; s <= SQ_H8; ++s) - if (off_A1H8(s) < 0) - MapB1H1H7[s] = code++; - - // MapA1D1D4[] encodes a square in the a1-d1-d4 triangle to 0..9 - std::vector diagonal; - code = 0; - for (Square s = SQ_A1; s <= SQ_D4; ++s) - if (off_A1H8(s) < 0 && file_of(s) <= FILE_D) - MapA1D1D4[s] = code++; - - else if (!off_A1H8(s) && file_of(s) <= FILE_D) - diagonal.push_back(s); - - // Diagonal squares are encoded as last ones - for (auto s : diagonal) - MapA1D1D4[s] = code++; - - // MapKK[] encodes all the 462 possible legal positions of two kings where - // the first is in the a1-d1-d4 triangle. If the first king is on the a1-d4 - // diagonal, the other one shall not to be above the a1-h8 diagonal. - std::vector> bothOnDiagonal; - code = 0; - for (int idx = 0; idx < 10; idx++) - for (Square s1 = SQ_A1; s1 <= SQ_D4; ++s1) - if (MapA1D1D4[s1] == idx && (idx || s1 == SQ_B1)) // SQ_B1 is mapped to 0 - { - for (Square s2 = SQ_A1; s2 <= SQ_H8; ++s2) - if ((PseudoAttacks[KING][s1] | s1) & s2) - continue; // Illegal position - - else if (!off_A1H8(s1) && off_A1H8(s2) > 0) - continue; // First on diagonal, second above - - else if (!off_A1H8(s1) && !off_A1H8(s2)) - bothOnDiagonal.emplace_back(idx, s2); - - else - MapKK[idx][s2] = code++; - } + // MapB1H1H7[] encodes a square below a1-h8 diagonal to 0..27 + int code = 0; + for (Square s = SQ_A1; s <= SQ_H8; ++s) + if (off_A1H8(s) < 0) MapB1H1H7[s] = code++; + + // MapA1D1D4[] encodes a square in the a1-d1-d4 triangle to 0..9 + std::vector diagonal; + code = 0; + for (Square s = SQ_A1; s <= SQ_D4; ++s) + if (off_A1H8(s) < 0 && file_of(s) <= FILE_D) + MapA1D1D4[s] = code++; + + else if (!off_A1H8(s) && file_of(s) <= FILE_D) + diagonal.push_back(s); - // Legal positions with both kings on diagonal are encoded as last ones - for (auto p : bothOnDiagonal) - MapKK[p.first][p.second] = code++; - - // Binomial[] stores the Binomial Coefficients using Pascal rule. There - // are Binomial[k][n] ways to choose k elements from a set of n elements. - Binomial[0][0] = 1; - - for (int n = 1; n < 64; n++) // Squares - for (int k = 0; k < 6 && k <= n; ++k) // Pieces - Binomial[k][n] = (k > 0 ? Binomial[k - 1][n - 1] : 0) - + (k < n ? Binomial[k ][n - 1] : 0); - - // MapPawns[s] encodes squares a2-h7 to 0..47. This is the number of possible - // available squares when the leading one is in 's'. Moreover the pawn with - // highest MapPawns[] is the leading pawn, the one nearest the edge and, - // among pawns with same file, the one with lowest rank. - int availableSquares = 47; // Available squares when lead pawn is in a2 - - // Init the tables for the encoding of leading pawns group: with 7-men TB we - // can have up to 5 leading pawns (KPPPPPK). - for (int leadPawnsCnt = 1; leadPawnsCnt <= 5; ++leadPawnsCnt) - for (File f = FILE_A; f <= FILE_D; ++f) - { - // Restart the index at every file because TB table is split - // by file, so we can reuse the same index for different files. - int idx = 0; - - // Sum all possible combinations for a given file, starting with - // the leading pawn on rank 2 and increasing the rank. - for (Rank r = RANK_2; r <= RANK_7; ++r) - { - Square sq = make_square(f, r); - - // Compute MapPawns[] at first pass. - // If sq is the leading pawn square, any other pawn cannot be - // below or more toward the edge of sq. There are 47 available - // squares when sq = a2 and reduced by 2 for any rank increase - // due to mirroring: sq == a3 -> no a2, h2, so MapPawns[a3] = 45 - if (leadPawnsCnt == 1) + // Diagonal squares are encoded as last ones + for (auto s : diagonal) MapA1D1D4[s] = code++; + + // MapKK[] encodes all the 462 possible legal positions of two kings where + // the first is in the a1-d1-d4 triangle. If the first king is on the a1-d4 + // diagonal, the other one shall not to be above the a1-h8 diagonal. + std::vector> bothOnDiagonal; + code = 0; + for (int idx = 0; idx < 10; idx++) + for (Square s1 = SQ_A1; s1 <= SQ_D4; ++s1) + if (MapA1D1D4[s1] == idx && (idx || s1 == SQ_B1)) // SQ_B1 is mapped to 0 { - MapPawns[sq] = availableSquares--; - MapPawns[flip_file(sq)] = availableSquares--; + for (Square s2 = SQ_A1; s2 <= SQ_H8; ++s2) + if ((PseudoAttacks[KING][s1] | s1) & s2) + continue; // Illegal position + + else if (!off_A1H8(s1) && off_A1H8(s2) > 0) + continue; // First on diagonal, second above + + else if (!off_A1H8(s1) && !off_A1H8(s2)) + bothOnDiagonal.emplace_back(idx, s2); + + else + MapKK[idx][s2] = code++; + } + + // Legal positions with both kings on diagonal are encoded as last ones + for (auto p : bothOnDiagonal) MapKK[p.first][p.second] = code++; + + // Binomial[] stores the Binomial Coefficients using Pascal rule. There + // are Binomial[k][n] ways to choose k elements from a set of n elements. + Binomial[0][0] = 1; + + for (int n = 1; n < 64; n++) // Squares + for (int k = 0; k < 6 && k <= n; ++k) // Pieces + Binomial[k][n] = + (k > 0 ? Binomial[k - 1][n - 1] : 0) + (k < n ? Binomial[k][n - 1] : 0); + + // MapPawns[s] encodes squares a2-h7 to 0..47. This is the number of possible + // available squares when the leading one is in 's'. Moreover the pawn with + // highest MapPawns[] is the leading pawn, the one nearest the edge and, + // among pawns with same file, the one with lowest rank. + int availableSquares = 47; // Available squares when lead pawn is in a2 + + // Init the tables for the encoding of leading pawns group: with 7-men TB we + // can have up to 5 leading pawns (KPPPPPK). + for (int leadPawnsCnt = 1; leadPawnsCnt <= 5; ++leadPawnsCnt) + for (File f = FILE_A; f <= FILE_D; ++f) { + // Restart the index at every file because TB table is split + // by file, so we can reuse the same index for different files. + int idx = 0; + + // Sum all possible combinations for a given file, starting with + // the leading pawn on rank 2 and increasing the rank. + for (Rank r = RANK_2; r <= RANK_7; ++r) { + Square sq = make_square(f, r); + + // Compute MapPawns[] at first pass. + // If sq is the leading pawn square, any other pawn cannot be + // below or more toward the edge of sq. There are 47 available + // squares when sq = a2 and reduced by 2 for any rank increase + // due to mirroring: sq == a3 -> no a2, h2, so MapPawns[a3] = 45 + if (leadPawnsCnt == 1) { + MapPawns[sq] = availableSquares--; + MapPawns[flip_file(sq)] = availableSquares--; + } + LeadPawnIdx[leadPawnsCnt][sq] = idx; + idx += Binomial[leadPawnsCnt - 1][MapPawns[sq]]; } - LeadPawnIdx[leadPawnsCnt][sq] = idx; - idx += Binomial[leadPawnsCnt - 1][MapPawns[sq]]; + // After a file is traversed, store the cumulated per-file index + LeadPawnsSize[leadPawnsCnt][f] = idx; } - // After a file is traversed, store the cumulated per-file index - LeadPawnsSize[leadPawnsCnt][f] = idx; - } - // Add entries in TB tables if the corresponding ".rtbw" file exists - for (PieceType p1 = PAWN; p1 < KING; ++p1) { - TBTables.add({KING, p1, KING}); + // Add entries in TB tables if the corresponding ".rtbw" file exists + for (PieceType p1 = PAWN; p1 < KING; ++p1) { + TBTables.add({KING, p1, KING}); - for (PieceType p2 = PAWN; p2 <= p1; ++p2) { - TBTables.add({KING, p1, p2, KING}); - TBTables.add({KING, p1, KING, p2}); + for (PieceType p2 = PAWN; p2 <= p1; ++p2) { + TBTables.add({KING, p1, p2, KING}); + TBTables.add({KING, p1, KING, p2}); - for (PieceType p3 = PAWN; p3 < KING; ++p3) - TBTables.add({KING, p1, p2, KING, p3}); + for (PieceType p3 = PAWN; p3 < KING; ++p3) TBTables.add({KING, p1, p2, KING, p3}); - for (PieceType p3 = PAWN; p3 <= p2; ++p3) { - TBTables.add({KING, p1, p2, p3, KING}); + for (PieceType p3 = PAWN; p3 <= p2; ++p3) { + TBTables.add({KING, p1, p2, p3, KING}); - for (PieceType p4 = PAWN; p4 <= p3; ++p4) { - TBTables.add({KING, p1, p2, p3, p4, KING}); + for (PieceType p4 = PAWN; p4 <= p3; ++p4) { + TBTables.add({KING, p1, p2, p3, p4, KING}); - for (PieceType p5 = PAWN; p5 <= p4; ++p5) - TBTables.add({KING, p1, p2, p3, p4, p5, KING}); + for (PieceType p5 = PAWN; p5 <= p4; ++p5) + TBTables.add({KING, p1, p2, p3, p4, p5, KING}); - for (PieceType p5 = PAWN; p5 < KING; ++p5) - TBTables.add({KING, p1, p2, p3, p4, KING, p5}); - } + for (PieceType p5 = PAWN; p5 < KING; ++p5) + TBTables.add({KING, p1, p2, p3, p4, KING, p5}); + } - for (PieceType p4 = PAWN; p4 < KING; ++p4) { - TBTables.add({KING, p1, p2, p3, KING, p4}); + for (PieceType p4 = PAWN; p4 < KING; ++p4) { + TBTables.add({KING, p1, p2, p3, KING, p4}); - for (PieceType p5 = PAWN; p5 <= p4; ++p5) - TBTables.add({KING, p1, p2, p3, KING, p4, p5}); + for (PieceType p5 = PAWN; p5 <= p4; ++p5) + TBTables.add({KING, p1, p2, p3, KING, p4, p5}); + } } - } - for (PieceType p3 = PAWN; p3 <= p1; ++p3) - for (PieceType p4 = PAWN; p4 <= (p1 == p3 ? p2 : p3); ++p4) - TBTables.add({KING, p1, p2, KING, p3, p4}); + for (PieceType p3 = PAWN; p3 <= p1; ++p3) + for (PieceType p4 = PAWN; p4 <= (p1 == p3 ? p2 : p3); ++p4) + TBTables.add({KING, p1, p2, KING, p3, p4}); + } } + + sync_cout << "info string Found " << TBTables.size() << " tablebases" << sync_endl; } - sync_cout << "info string Found " << TBTables.size() << " tablebases" << sync_endl; -} - -// Probe the WDL table for a particular position. -// If *result != FAIL, the probe was successful. -// The return value is from the point of view of the side to move: -// -2 : loss -// -1 : loss, but draw under 50-move rule -// 0 : draw -// 1 : win, but draw under 50-move rule -// 2 : win -WDLScore Tablebases::probe_wdl(Position& pos, ProbeState* result) { - - *result = OK; - return search(pos, result); -} - -// Probe the DTZ table for a particular position. -// If *result != FAIL, the probe was successful. -// The return value is from the point of view of the side to move: -// n < -100 : loss, but draw under 50-move rule -// -100 <= n < -1 : loss in n ply (assuming 50-move counter == 0) -// -1 : loss, the side to move is mated -// 0 : draw -// 1 < n <= 100 : win in n ply (assuming 50-move counter == 0) -// 100 < n : win, but draw under 50-move rule -// -// The return value n can be off by 1: a return value -n can mean a loss -// in n+1 ply and a return value +n can mean a win in n+1 ply. This -// cannot happen for tables with positions exactly on the "edge" of -// the 50-move rule. -// -// This implies that if dtz > 0 is returned, the position is certainly -// a win if dtz + 50-move-counter <= 99. Care must be taken that the engine -// picks moves that preserve dtz + 50-move-counter <= 99. -// -// If n = 100 immediately after a capture or pawn move, then the position -// is also certainly a win, and during the whole phase until the next -// capture or pawn move, the inequality to be preserved is -// dtz + 50-move-counter <= 100. -// -// In short, if a move is available resulting in dtz + 50-move-counter <= 99, -// then do not accept moves leading to dtz + 50-move-counter == 100. -int Tablebases::probe_dtz(Position& pos, ProbeState* result) { - - *result = OK; - WDLScore wdl = search(pos, result); - - if (*result == FAIL || wdl == WDLDraw) // DTZ tables don't store draws - return 0; - - // DTZ stores a 'don't care' value in this case, or even a plain wrong - // one as in case the best move is a losing ep, so it cannot be probed. - if (*result == ZEROING_BEST_MOVE) - return dtz_before_zeroing(wdl); - - int dtz = probe_table(pos, result, wdl); - - if (*result == FAIL) - return 0; - - if (*result != CHANGE_STM) - return (dtz + 100 * (wdl == WDLBlessedLoss || wdl == WDLCursedWin)) * sign_of(wdl); - - // DTZ stores results for the other side, so we need to do a 1-ply search and - // find the winning move that minimizes DTZ. - StateInfo st; - int minDTZ = 0xFFFF; - - for (const Move move : MoveList(pos)) - { - bool zeroing = pos.capture(move) || type_of(pos.moved_piece(move)) == PAWN; - - pos.do_move(move, st); - - // For zeroing moves we want the dtz of the move _before_ doing it, - // otherwise we will get the dtz of the next move sequence. Search the - // position after the move to get the score sign (because even in a - // winning position we could make a losing capture or going for a draw). - dtz = zeroing ? -dtz_before_zeroing(search(pos, result)) - : -probe_dtz(pos, result); - - // If the move mates, force minDTZ to 1 - if (dtz == 1 && pos.checkers() && MoveList(pos).size() == 0) - minDTZ = 1; - - // Convert result from 1-ply search. Zeroing moves are already accounted - // by dtz_before_zeroing() that returns the DTZ of the previous move. - if (!zeroing) - dtz += sign_of(dtz); - - // Skip the draws and if we are winning only pick positive dtz - if (dtz < minDTZ && sign_of(dtz) == sign_of(wdl)) - minDTZ = dtz; - - pos.undo_move(move); - - if (*result == FAIL) - return 0; + // Probe the WDL table for a particular position. + // If *result != FAIL, the probe was successful. + // The return value is from the point of view of the side to move: + // -2 : loss + // -1 : loss, but draw under 50-move rule + // 0 : draw + // 1 : win, but draw under 50-move rule + // 2 : win + WDLScore Tablebases::probe_wdl(Position& pos, ProbeState* result) { + + *result = OK; + return search(pos, result); } - // When there are no legal moves, the position is mate: we return -1 - return minDTZ == 0xFFFF ? -1 : minDTZ; -} + // Probe the DTZ table for a particular position. + // If *result != FAIL, the probe was successful. + // The return value is from the point of view of the side to move: + // n < -100 : loss, but draw under 50-move rule + // -100 <= n < -1 : loss in n ply (assuming 50-move counter == 0) + // -1 : loss, the side to move is mated + // 0 : draw + // 1 < n <= 100 : win in n ply (assuming 50-move counter == 0) + // 100 < n : win, but draw under 50-move rule + // + // The return value n can be off by 1: a return value -n can mean a loss + // in n+1 ply and a return value +n can mean a win in n+1 ply. This + // cannot happen for tables with positions exactly on the "edge" of + // the 50-move rule. + // + // This implies that if dtz > 0 is returned, the position is certainly + // a win if dtz + 50-move-counter <= 99. Care must be taken that the engine + // picks moves that preserve dtz + 50-move-counter <= 99. + // + // If n = 100 immediately after a capture or pawn move, then the position + // is also certainly a win, and during the whole phase until the next + // capture or pawn move, the inequality to be preserved is + // dtz + 50-move-counter <= 100. + // + // In short, if a move is available resulting in dtz + 50-move-counter <= 99, + // then do not accept moves leading to dtz + 50-move-counter == 100. + int Tablebases::probe_dtz(Position& pos, ProbeState* result) { + + *result = OK; + WDLScore wdl = search(pos, result); + if (*result == FAIL || wdl == WDLDraw) // DTZ tables don't store draws + return 0; -// Use the DTZ tables to rank root moves. -// -// A return value false indicates that not all probes were successful. -bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves) { + // DTZ stores a 'don't care' value in this case, or even a plain wrong + // one as in case the best move is a losing ep, so it cannot be probed. + if (*result == ZEROING_BEST_MOVE) return dtz_before_zeroing(wdl); - ProbeState result = OK; - StateInfo st; + int dtz = probe_table(pos, result, wdl); - // Obtain 50-move counter for the root position - int cnt50 = pos.rule50_count(); + if (*result == FAIL) return 0; - // Check whether a position was repeated since the last zeroing move. - bool rep = pos.has_repeated(); + if (*result != CHANGE_STM) + return (dtz + 100 * (wdl == WDLBlessedLoss || wdl == WDLCursedWin)) * sign_of(wdl); - int dtz, bound = Options["Syzygy50MoveRule"] ? (MAX_DTZ - 100) : 1; + // DTZ stores results for the other side, so we need to do a 1-ply search and + // find the winning move that minimizes DTZ. + StateInfo st; + int minDTZ = 0xFFFF; - // Probe and rank each move - for (auto& m : rootMoves) - { - pos.do_move(m.pv[0], st); + for (const Move move : MoveList(pos)) { + bool zeroing = pos.capture(move) || type_of(pos.moved_piece(move)) == PAWN; - // Calculate dtz for the current move counting from the root position - if (pos.rule50_count() == 0) - { - // In case of a zeroing move, dtz is one of -101/-1/0/1/101 - WDLScore wdl = -probe_wdl(pos, &result); - dtz = dtz_before_zeroing(wdl); - } - else if (pos.is_draw(1)) - { - // In case a root move leads to a draw by repetition or - // 50-move rule, we set dtz to zero. Note: since we are - // only 1 ply from the root, this must be a true 3-fold - // repetition inside the game history. - dtz = 0; + pos.do_move(move, st); + + // For zeroing moves we want the dtz of the move _before_ doing it, + // otherwise we will get the dtz of the next move sequence. Search the + // position after the move to get the score sign (because even in a + // winning position we could make a losing capture or going for a draw). + dtz = + zeroing ? -dtz_before_zeroing(search(pos, result)) : -probe_dtz(pos, result); + + // If the move mates, force minDTZ to 1 + if (dtz == 1 && pos.checkers() && MoveList(pos).size() == 0) minDTZ = 1; + + // Convert result from 1-ply search. Zeroing moves are already accounted + // by dtz_before_zeroing() that returns the DTZ of the previous move. + if (!zeroing) dtz += sign_of(dtz); + + // Skip the draws and if we are winning only pick positive dtz + if (dtz < minDTZ && sign_of(dtz) == sign_of(wdl)) minDTZ = dtz; + + pos.undo_move(move); + + if (*result == FAIL) return 0; } - else - { - // Otherwise, take dtz for the new position and correct by 1 ply - dtz = -probe_dtz(pos, &result); - dtz = dtz > 0 ? dtz + 1 - : dtz < 0 ? dtz - 1 : dtz; + + // When there are no legal moves, the position is mate: we return -1 + return minDTZ == 0xFFFF ? -1 : minDTZ; + } + + + // Use the DTZ tables to rank root moves. + // + // A return value false indicates that not all probes were successful. + bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves) { + + ProbeState result = OK; + StateInfo st; + + // Obtain 50-move counter for the root position + int cnt50 = pos.rule50_count(); + + // Check whether a position was repeated since the last zeroing move. + bool rep = pos.has_repeated(); + + int dtz, bound = Options["Syzygy50MoveRule"] ? (MAX_DTZ - 100) : 1; + + // Probe and rank each move + for (auto& m : rootMoves) { + pos.do_move(m.pv[0], st); + + // Calculate dtz for the current move counting from the root position + if (pos.rule50_count() == 0) { + // In case of a zeroing move, dtz is one of -101/-1/0/1/101 + WDLScore wdl = -probe_wdl(pos, &result); + dtz = dtz_before_zeroing(wdl); + } else if (pos.is_draw(1)) { + // In case a root move leads to a draw by repetition or + // 50-move rule, we set dtz to zero. Note: since we are + // only 1 ply from the root, this must be a true 3-fold + // repetition inside the game history. + dtz = 0; + } else { + // Otherwise, take dtz for the new position and correct by 1 ply + dtz = -probe_dtz(pos, &result); + dtz = dtz > 0 ? dtz + 1 : dtz < 0 ? dtz - 1 : dtz; + } + + // Make sure that a mating move is assigned a dtz value of 1 + if (pos.checkers() && dtz == 2 && MoveList(pos).size() == 0) dtz = 1; + + pos.undo_move(m.pv[0]); + + if (result == FAIL) return false; + + // Better moves are ranked higher. Certain wins are ranked equally. + // Losing moves are ranked equally unless a 50-move draw is in sight. + int r = dtz > 0 ? (dtz + cnt50 <= 99 && !rep ? MAX_DTZ : MAX_DTZ - (dtz + cnt50)) : + dtz < 0 ? (-dtz * 2 + cnt50 < 100 ? -MAX_DTZ : -MAX_DTZ + (-dtz + cnt50)) : + 0; + m.tbRank = r; + + // Determine the score to be displayed for this move. Assign at least + // 1 cp to cursed wins and let it grow to 49 cp as the positions gets + // closer to a real win. + m.tbScore = r >= bound ? VALUE_MATE - MAX_PLY - 1 : + r > 0 ? Value((std::max(3, r - (MAX_DTZ - 200)) * int(PawnValue)) / 200) : + r == 0 ? VALUE_DRAW : + r > -bound ? + Value((std::min(-3, r + (MAX_DTZ - 200)) * int(PawnValue)) / 200) : + -VALUE_MATE + MAX_PLY + 1; } - // Make sure that a mating move is assigned a dtz value of 1 - if ( pos.checkers() - && dtz == 2 - && MoveList(pos).size() == 0) - dtz = 1; - - pos.undo_move(m.pv[0]); - - if (result == FAIL) - return false; - - // Better moves are ranked higher. Certain wins are ranked equally. - // Losing moves are ranked equally unless a 50-move draw is in sight. - int r = dtz > 0 ? (dtz + cnt50 <= 99 && !rep ? MAX_DTZ : MAX_DTZ - (dtz + cnt50)) - : dtz < 0 ? (-dtz * 2 + cnt50 < 100 ? -MAX_DTZ : -MAX_DTZ + (-dtz + cnt50)) - : 0; - m.tbRank = r; - - // Determine the score to be displayed for this move. Assign at least - // 1 cp to cursed wins and let it grow to 49 cp as the positions gets - // closer to a real win. - m.tbScore = r >= bound ? VALUE_MATE - MAX_PLY - 1 - : r > 0 ? Value((std::max( 3, r - (MAX_DTZ - 200)) * int(PawnValue)) / 200) - : r == 0 ? VALUE_DRAW - : r > -bound ? Value((std::min(-3, r + (MAX_DTZ - 200)) * int(PawnValue)) / 200) - : -VALUE_MATE + MAX_PLY + 1; + return true; } - return true; -} + // Use the WDL tables to rank root moves. + // This is a fallback for the case that some or all DTZ tables are missing. + // + // A return value false indicates that not all probes were successful. + bool Tablebases::root_probe_wdl(Position& pos, Search::RootMoves& rootMoves) { -// Use the WDL tables to rank root moves. -// This is a fallback for the case that some or all DTZ tables are missing. -// -// A return value false indicates that not all probes were successful. -bool Tablebases::root_probe_wdl(Position& pos, Search::RootMoves& rootMoves) { + static const int WDL_to_rank[] = {-MAX_DTZ, -MAX_DTZ + 101, 0, MAX_DTZ - 101, MAX_DTZ}; - static const int WDL_to_rank[] = { -MAX_DTZ, -MAX_DTZ + 101, 0, MAX_DTZ - 101, MAX_DTZ }; + ProbeState result = OK; + StateInfo st; + WDLScore wdl; - ProbeState result = OK; - StateInfo st; - WDLScore wdl; + bool rule50 = Options["Syzygy50MoveRule"]; - bool rule50 = Options["Syzygy50MoveRule"]; + // Probe and rank each move + for (auto& m : rootMoves) { + pos.do_move(m.pv[0], st); - // Probe and rank each move - for (auto& m : rootMoves) - { - pos.do_move(m.pv[0], st); + if (pos.is_draw(1)) + wdl = WDLDraw; + else + wdl = -probe_wdl(pos, &result); - if (pos.is_draw(1)) - wdl = WDLDraw; - else - wdl = -probe_wdl(pos, &result); + pos.undo_move(m.pv[0]); - pos.undo_move(m.pv[0]); + if (result == FAIL) return false; - if (result == FAIL) - return false; + m.tbRank = WDL_to_rank[wdl + 2]; - m.tbRank = WDL_to_rank[wdl + 2]; + if (!rule50) wdl = wdl > WDLDraw ? WDLWin : wdl < WDLDraw ? WDLLoss : WDLDraw; + m.tbScore = WDL_to_value[wdl + 2]; + } - if (!rule50) - wdl = wdl > WDLDraw ? WDLWin - : wdl < WDLDraw ? WDLLoss : WDLDraw; - m.tbScore = WDL_to_value[wdl + 2]; + return true; } - return true; -} - -} // namespace Stockfish +} // namespace Stockfish diff --git a/src/syzygy/tbprobe.h b/src/syzygy/tbprobe.h index b2ba35ff4b0..e7ff6a66ebc 100644 --- a/src/syzygy/tbprobe.h +++ b/src/syzygy/tbprobe.h @@ -24,36 +24,36 @@ #include "../search.h" namespace Stockfish { -class Position; + class Position; } namespace Stockfish::Tablebases { -enum WDLScore { - WDLLoss = -2, // Loss - WDLBlessedLoss = -1, // Loss, but draw under 50-move rule - WDLDraw = 0, // Draw - WDLCursedWin = 1, // Win, but draw under 50-move rule - WDLWin = 2, // Win -}; - -// Possible states after a probing operation -enum ProbeState { - FAIL = 0, // Probe failed (missing file table) - OK = 1, // Probe successful - CHANGE_STM = -1, // DTZ should check the other side - ZEROING_BEST_MOVE = 2 // Best move zeroes DTZ (capture or pawn move) -}; - -extern int MaxCardinality; - -void init(const std::string& paths); -WDLScore probe_wdl(Position& pos, ProbeState* result); -int probe_dtz(Position& pos, ProbeState* result); -bool root_probe(Position& pos, Search::RootMoves& rootMoves); -bool root_probe_wdl(Position& pos, Search::RootMoves& rootMoves); -void rank_root_moves(Position& pos, Search::RootMoves& rootMoves); - -} // namespace Stockfish::Tablebases + enum WDLScore { + WDLLoss = -2, // Loss + WDLBlessedLoss = -1, // Loss, but draw under 50-move rule + WDLDraw = 0, // Draw + WDLCursedWin = 1, // Win, but draw under 50-move rule + WDLWin = 2, // Win + }; + + // Possible states after a probing operation + enum ProbeState { + FAIL = 0, // Probe failed (missing file table) + OK = 1, // Probe successful + CHANGE_STM = -1, // DTZ should check the other side + ZEROING_BEST_MOVE = 2 // Best move zeroes DTZ (capture or pawn move) + }; + + extern int MaxCardinality; + + void init(const std::string& paths); + WDLScore probe_wdl(Position& pos, ProbeState* result); + int probe_dtz(Position& pos, ProbeState* result); + bool root_probe(Position& pos, Search::RootMoves& rootMoves); + bool root_probe_wdl(Position& pos, Search::RootMoves& rootMoves); + void rank_root_moves(Position& pos, Search::RootMoves& rootMoves); + +} // namespace Stockfish::Tablebases #endif diff --git a/src/thread.cpp b/src/thread.cpp index 60f760ed46e..f3c9268d991 100644 --- a/src/thread.cpp +++ b/src/thread.cpp @@ -37,242 +37,227 @@ namespace Stockfish { -ThreadPool Threads; // Global object + ThreadPool Threads; // Global object -/// Thread constructor launches the thread and waits until it goes to sleep -/// in idle_loop(). Note that 'searching' and 'exit' should be already set. + /// Thread constructor launches the thread and waits until it goes to sleep + /// in idle_loop(). Note that 'searching' and 'exit' should be already set. -Thread::Thread(size_t n) : idx(n), stdThread(&Thread::idle_loop, this) { + Thread::Thread(size_t n) : idx(n), stdThread(&Thread::idle_loop, this) { - wait_for_search_finished(); -} + wait_for_search_finished(); + } -/// Thread destructor wakes up the thread in idle_loop() and waits -/// for its termination. Thread should be already waiting. + /// Thread destructor wakes up the thread in idle_loop() and waits + /// for its termination. Thread should be already waiting. -Thread::~Thread() { + Thread::~Thread() { - assert(!searching); + assert(!searching); - exit = true; - start_searching(); - stdThread.join(); -} + exit = true; + start_searching(); + stdThread.join(); + } -/// Thread::clear() reset histories, usually before a new game + /// Thread::clear() reset histories, usually before a new game -void Thread::clear() { + void Thread::clear() { - counterMoves.fill(MOVE_NONE); - mainHistory.fill(0); - captureHistory.fill(0); + counterMoves.fill(MOVE_NONE); + mainHistory.fill(0); + captureHistory.fill(0); - for (bool inCheck : { false, true }) - for (StatsType c : { NoCaptures, Captures }) - for (auto& to : continuationHistory[inCheck][c]) - for (auto& h : to) - h->fill(-71); -} + for (bool inCheck : {false, true}) + for (StatsType c : {NoCaptures, Captures}) + for (auto& to : continuationHistory[inCheck][c]) + for (auto& h : to) h->fill(-71); + } -/// Thread::start_searching() wakes up the thread that will start the search + /// Thread::start_searching() wakes up the thread that will start the search -void Thread::start_searching() { - mutex.lock(); - searching = true; - mutex.unlock(); // Unlock before notifying saves a few CPU-cycles - cv.notify_one(); // Wake up the thread in idle_loop() -} + void Thread::start_searching() { + mutex.lock(); + searching = true; + mutex.unlock(); // Unlock before notifying saves a few CPU-cycles + cv.notify_one(); // Wake up the thread in idle_loop() + } -/// Thread::wait_for_search_finished() blocks on the condition variable -/// until the thread has finished searching. + /// Thread::wait_for_search_finished() blocks on the condition variable + /// until the thread has finished searching. -void Thread::wait_for_search_finished() { + void Thread::wait_for_search_finished() { - std::unique_lock lk(mutex); - cv.wait(lk, [&]{ return !searching; }); -} + std::unique_lock lk(mutex); + cv.wait(lk, [&] { return !searching; }); + } -/// Thread::idle_loop() is where the thread is parked, blocked on the -/// condition variable, when it has no work to do. + /// Thread::idle_loop() is where the thread is parked, blocked on the + /// condition variable, when it has no work to do. -void Thread::idle_loop() { + void Thread::idle_loop() { - // If OS already scheduled us on a different group than 0 then don't overwrite - // the choice, eventually we are one of many one-threaded processes running on - // some Windows NUMA hardware, for instance in fishtest. To make it simple, - // just check if running threads are below a threshold, in this case all this - // NUMA machinery is not needed. - if (Options["Threads"] > 8) - WinProcGroup::bindThisThread(idx); + // If OS already scheduled us on a different group than 0 then don't overwrite + // the choice, eventually we are one of many one-threaded processes running on + // some Windows NUMA hardware, for instance in fishtest. To make it simple, + // just check if running threads are below a threshold, in this case all this + // NUMA machinery is not needed. + if (Options["Threads"] > 8) WinProcGroup::bindThisThread(idx); - while (true) - { - std::unique_lock lk(mutex); - searching = false; - cv.notify_one(); // Wake up anyone waiting for search finished - cv.wait(lk, [&]{ return searching; }); + while (true) { + std::unique_lock lk(mutex); + searching = false; + cv.notify_one(); // Wake up anyone waiting for search finished + cv.wait(lk, [&] { return searching; }); - if (exit) - return; + if (exit) return; - lk.unlock(); + lk.unlock(); - search(); - } -} + search(); + } + } -/// ThreadPool::set() creates/destroys threads to match the requested number. -/// Created and launched threads will immediately go to sleep in idle_loop. -/// Upon resizing, threads are recreated to allow for binding if necessary. + /// ThreadPool::set() creates/destroys threads to match the requested number. + /// Created and launched threads will immediately go to sleep in idle_loop. + /// Upon resizing, threads are recreated to allow for binding if necessary. -void ThreadPool::set(size_t requested) { + void ThreadPool::set(size_t requested) { - if (threads.size() > 0) // destroy any existing thread(s) - { - main()->wait_for_search_finished(); + if (threads.size() > 0) // destroy any existing thread(s) + { + main()->wait_for_search_finished(); - while (threads.size() > 0) - delete threads.back(), threads.pop_back(); - } + while (threads.size() > 0) delete threads.back(), threads.pop_back(); + } - if (requested > 0) // create new thread(s) - { - threads.push_back(new MainThread(0)); + if (requested > 0) // create new thread(s) + { + threads.push_back(new MainThread(0)); - while (threads.size() < requested) - threads.push_back(new Thread(threads.size())); - clear(); + while (threads.size() < requested) threads.push_back(new Thread(threads.size())); + clear(); - // Reallocate the hash with the new threadpool size - TT.resize(size_t(Options["Hash"])); + // Reallocate the hash with the new threadpool size + TT.resize(size_t(Options["Hash"])); - // Init thread number dependent search params. - Search::init(); - } -} + // Init thread number dependent search params. + Search::init(); + } + } -/// ThreadPool::clear() sets threadPool data to initial values + /// ThreadPool::clear() sets threadPool data to initial values -void ThreadPool::clear() { + void ThreadPool::clear() { - for (Thread* th : threads) - th->clear(); + for (Thread* th : threads) th->clear(); - main()->callsCnt = 0; - main()->bestPreviousScore = VALUE_INFINITE; - main()->bestPreviousAverageScore = VALUE_INFINITE; - main()->previousTimeReduction = 1.0; -} + main()->callsCnt = 0; + main()->bestPreviousScore = VALUE_INFINITE; + main()->bestPreviousAverageScore = VALUE_INFINITE; + main()->previousTimeReduction = 1.0; + } -/// ThreadPool::start_thinking() wakes up main thread waiting in idle_loop() and -/// returns immediately. Main thread will wake up other threads and start the search. + /// ThreadPool::start_thinking() wakes up main thread waiting in idle_loop() and + /// returns immediately. Main thread will wake up other threads and start the search. -void ThreadPool::start_thinking(Position& pos, StateListPtr& states, - const Search::LimitsType& limits, bool ponderMode) { + void ThreadPool::start_thinking(Position& pos, StateListPtr& states, + const Search::LimitsType& limits, bool ponderMode) { - main()->wait_for_search_finished(); + main()->wait_for_search_finished(); - main()->stopOnPonderhit = stop = false; - increaseDepth = true; - main()->ponder = ponderMode; - Search::Limits = limits; - Search::RootMoves rootMoves; + main()->stopOnPonderhit = stop = false; + increaseDepth = true; + main()->ponder = ponderMode; + Search::Limits = limits; + Search::RootMoves rootMoves; - for (const auto& m : MoveList(pos)) - if ( limits.searchmoves.empty() - || std::count(limits.searchmoves.begin(), limits.searchmoves.end(), m)) - rootMoves.emplace_back(m); + for (const auto& m : MoveList(pos)) + if (limits.searchmoves.empty() || + std::count(limits.searchmoves.begin(), limits.searchmoves.end(), m)) + rootMoves.emplace_back(m); - if (!rootMoves.empty()) - Tablebases::rank_root_moves(pos, rootMoves); + if (!rootMoves.empty()) Tablebases::rank_root_moves(pos, rootMoves); - // After ownership transfer 'states' becomes empty, so if we stop the search - // and call 'go' again without setting a new position states.get() == nullptr. - assert(states.get() || setupStates.get()); + // After ownership transfer 'states' becomes empty, so if we stop the search + // and call 'go' again without setting a new position states.get() == nullptr. + assert(states.get() || setupStates.get()); - if (states.get()) - setupStates = std::move(states); // Ownership transfer, states is now empty + if (states.get()) + setupStates = std::move(states); // Ownership transfer, states is now empty - // We use Position::set() to set root position across threads. But there are - // some StateInfo fields (previous, pliesFromNull, capturedPiece) that cannot - // be deduced from a fen string, so set() clears them and they are set from - // setupStates->back() later. The rootState is per thread, earlier states are shared - // since they are read-only. - for (Thread* th : threads) - { - th->nodes = th->tbHits = th->nmpMinPly = th->bestMoveChanges = 0; - th->rootDepth = th->completedDepth = 0; - th->rootMoves = rootMoves; - th->rootPos.set(pos.fen(), pos.is_chess960(), &th->rootState, th); - th->rootState = setupStates->back(); - th->rootSimpleEval = Eval::simple_eval(pos, pos.side_to_move()); - } + // We use Position::set() to set root position across threads. But there are + // some StateInfo fields (previous, pliesFromNull, capturedPiece) that cannot + // be deduced from a fen string, so set() clears them and they are set from + // setupStates->back() later. The rootState is per thread, earlier states are shared + // since they are read-only. + for (Thread* th : threads) { + th->nodes = th->tbHits = th->nmpMinPly = th->bestMoveChanges = 0; + th->rootDepth = th->completedDepth = 0; + th->rootMoves = rootMoves; + th->rootPos.set(pos.fen(), pos.is_chess960(), &th->rootState, th); + th->rootState = setupStates->back(); + th->rootSimpleEval = Eval::simple_eval(pos, pos.side_to_move()); + } - main()->start_searching(); -} + main()->start_searching(); + } -Thread* ThreadPool::get_best_thread() const { + Thread* ThreadPool::get_best_thread() const { - Thread* bestThread = threads.front(); - std::map votes; - Value minScore = VALUE_NONE; + Thread* bestThread = threads.front(); + std::map votes; + Value minScore = VALUE_NONE; - // Find minimum score of all threads - for (Thread* th: threads) - minScore = std::min(minScore, th->rootMoves[0].score); + // Find minimum score of all threads + for (Thread* th : threads) minScore = std::min(minScore, th->rootMoves[0].score); - // Vote according to score and depth, and select the best thread - auto thread_value = [minScore](Thread* th) { + // Vote according to score and depth, and select the best thread + auto thread_value = [minScore](Thread* th) { return (th->rootMoves[0].score - minScore + 14) * int(th->completedDepth); }; - for (Thread* th : threads) - votes[th->rootMoves[0].pv[0]] += thread_value(th); - - for (Thread* th : threads) - if (abs(bestThread->rootMoves[0].score) >= VALUE_TB_WIN_IN_MAX_PLY) - { - // Make sure we pick the shortest mate / TB conversion or stave off mate the longest - if (th->rootMoves[0].score > bestThread->rootMoves[0].score) + for (Thread* th : threads) votes[th->rootMoves[0].pv[0]] += thread_value(th); + + for (Thread* th : threads) + if (abs(bestThread->rootMoves[0].score) >= VALUE_TB_WIN_IN_MAX_PLY) { + // Make sure we pick the shortest mate / TB conversion or stave off mate the longest + if (th->rootMoves[0].score > bestThread->rootMoves[0].score) bestThread = th; + } else if (th->rootMoves[0].score >= VALUE_TB_WIN_IN_MAX_PLY || + (th->rootMoves[0].score > VALUE_TB_LOSS_IN_MAX_PLY && + (votes[th->rootMoves[0].pv[0]] > votes[bestThread->rootMoves[0].pv[0]] || + (votes[th->rootMoves[0].pv[0]] == votes[bestThread->rootMoves[0].pv[0]] && + thread_value(th) * int(th->rootMoves[0].pv.size() > 2) > + thread_value(bestThread) * + int(bestThread->rootMoves[0].pv.size() > 2))))) bestThread = th; - } - else if ( th->rootMoves[0].score >= VALUE_TB_WIN_IN_MAX_PLY - || ( th->rootMoves[0].score > VALUE_TB_LOSS_IN_MAX_PLY - && ( votes[th->rootMoves[0].pv[0]] > votes[bestThread->rootMoves[0].pv[0]] - || ( votes[th->rootMoves[0].pv[0]] == votes[bestThread->rootMoves[0].pv[0]] - && thread_value(th) * int(th->rootMoves[0].pv.size() > 2) - > thread_value(bestThread) * int(bestThread->rootMoves[0].pv.size() > 2))))) - bestThread = th; - return bestThread; -} + return bestThread; + } -/// Start non-main threads + /// Start non-main threads -void ThreadPool::start_searching() { + void ThreadPool::start_searching() { - for (Thread* th : threads) - if (th != threads.front()) - th->start_searching(); -} + for (Thread* th : threads) + if (th != threads.front()) th->start_searching(); + } -/// Wait for non-main threads + /// Wait for non-main threads -void ThreadPool::wait_for_search_finished() const { + void ThreadPool::wait_for_search_finished() const { - for (Thread* th : threads) - if (th != threads.front()) - th->wait_for_search_finished(); -} + for (Thread* th : threads) + if (th != threads.front()) th->wait_for_search_finished(); + } -} // namespace Stockfish +} // namespace Stockfish diff --git a/src/thread.h b/src/thread.h index 8d0adcf0340..742b77450f5 100644 --- a/src/thread.h +++ b/src/thread.h @@ -34,107 +34,106 @@ namespace Stockfish { -/// Thread class keeps together all the thread-related stuff. We use -/// per-thread pawn and material hash tables so that once we get a -/// pointer to an entry its life time is unlimited and we don't have -/// to care about someone changing the entry under our feet. - -class Thread { - - std::mutex mutex; - std::condition_variable cv; - size_t idx; - bool exit = false, searching = true; // Set before starting std::thread - NativeThread stdThread; - -public: - explicit Thread(size_t); - virtual ~Thread(); - virtual void search(); - void clear(); - void idle_loop(); - void start_searching(); - void wait_for_search_finished(); - size_t id() const { return idx; } - - size_t pvIdx, pvLast; - std::atomic nodes, tbHits, bestMoveChanges; - int selDepth, nmpMinPly; - Value bestValue, optimism[COLOR_NB]; - - Position rootPos; - StateInfo rootState; - Search::RootMoves rootMoves; - Depth rootDepth, completedDepth; - Value rootDelta; - Value rootSimpleEval; - CounterMoveHistory counterMoves; - ButterflyHistory mainHistory; - CapturePieceToHistory captureHistory; - ContinuationHistory continuationHistory[2][2]; -}; - - -/// MainThread is a derived class specific for main thread - -struct MainThread : public Thread { - - using Thread::Thread; - - void search() override; - void check_time(); - - double previousTimeReduction; - Value bestPreviousScore; - Value bestPreviousAverageScore; - Value iterValue[4]; - int callsCnt; - bool stopOnPonderhit; - std::atomic_bool ponder; -}; - - -/// ThreadPool struct handles all the threads-related stuff like init, starting, -/// parking and, most importantly, launching a thread. All the access to threads -/// is done through this class. - -struct ThreadPool { - - void start_thinking(Position&, StateListPtr&, const Search::LimitsType&, bool = false); - void clear(); - void set(size_t); - - MainThread* main() const { return static_cast(threads.front()); } - uint64_t nodes_searched() const { return accumulate(&Thread::nodes); } - uint64_t tb_hits() const { return accumulate(&Thread::tbHits); } - Thread* get_best_thread() const; - void start_searching(); - void wait_for_search_finished() const; - - std::atomic_bool stop, increaseDepth; - - auto cbegin() const noexcept { return threads.cbegin(); } - auto begin() noexcept { return threads.begin(); } - auto end() noexcept { return threads.end(); } - auto cend() const noexcept { return threads.cend(); } - auto size() const noexcept { return threads.size(); } - auto empty() const noexcept { return threads.empty(); } - -private: - StateListPtr setupStates; - std::vector threads; - - uint64_t accumulate(std::atomic Thread::* member) const { - - uint64_t sum = 0; - for (Thread* th : threads) - sum += (th->*member).load(std::memory_order_relaxed); - return sum; - } -}; - -extern ThreadPool Threads; - -} // namespace Stockfish - -#endif // #ifndef THREAD_H_INCLUDED + /// Thread class keeps together all the thread-related stuff. We use + /// per-thread pawn and material hash tables so that once we get a + /// pointer to an entry its life time is unlimited and we don't have + /// to care about someone changing the entry under our feet. + + class Thread { + + std::mutex mutex; + std::condition_variable cv; + size_t idx; + bool exit = false, searching = true; // Set before starting std::thread + NativeThread stdThread; + + public: + explicit Thread(size_t); + virtual ~Thread(); + virtual void search(); + void clear(); + void idle_loop(); + void start_searching(); + void wait_for_search_finished(); + size_t id() const { return idx; } + + size_t pvIdx, pvLast; + std::atomic nodes, tbHits, bestMoveChanges; + int selDepth, nmpMinPly; + Value bestValue, optimism[COLOR_NB]; + + Position rootPos; + StateInfo rootState; + Search::RootMoves rootMoves; + Depth rootDepth, completedDepth; + Value rootDelta; + Value rootSimpleEval; + CounterMoveHistory counterMoves; + ButterflyHistory mainHistory; + CapturePieceToHistory captureHistory; + ContinuationHistory continuationHistory[2][2]; + }; + + + /// MainThread is a derived class specific for main thread + + struct MainThread: public Thread { + + using Thread::Thread; + + void search() override; + void check_time(); + + double previousTimeReduction; + Value bestPreviousScore; + Value bestPreviousAverageScore; + Value iterValue[4]; + int callsCnt; + bool stopOnPonderhit; + std::atomic_bool ponder; + }; + + + /// ThreadPool struct handles all the threads-related stuff like init, starting, + /// parking and, most importantly, launching a thread. All the access to threads + /// is done through this class. + + struct ThreadPool { + + void start_thinking(Position&, StateListPtr&, const Search::LimitsType&, bool = false); + void clear(); + void set(size_t); + + MainThread* main() const { return static_cast(threads.front()); } + uint64_t nodes_searched() const { return accumulate(&Thread::nodes); } + uint64_t tb_hits() const { return accumulate(&Thread::tbHits); } + Thread* get_best_thread() const; + void start_searching(); + void wait_for_search_finished() const; + + std::atomic_bool stop, increaseDepth; + + auto cbegin() const noexcept { return threads.cbegin(); } + auto begin() noexcept { return threads.begin(); } + auto end() noexcept { return threads.end(); } + auto cend() const noexcept { return threads.cend(); } + auto size() const noexcept { return threads.size(); } + auto empty() const noexcept { return threads.empty(); } + + private: + StateListPtr setupStates; + std::vector threads; + + uint64_t accumulate(std::atomic Thread::*member) const { + + uint64_t sum = 0; + for (Thread* th : threads) sum += (th->*member).load(std::memory_order_relaxed); + return sum; + } + }; + + extern ThreadPool Threads; + +} // namespace Stockfish + +#endif // #ifndef THREAD_H_INCLUDED diff --git a/src/thread_win32_osx.h b/src/thread_win32_osx.h index 330a8341dd8..c061cc01cf4 100644 --- a/src/thread_win32_osx.h +++ b/src/thread_win32_osx.h @@ -29,46 +29,44 @@ #if defined(__APPLE__) || defined(__MINGW32__) || defined(__MINGW64__) || defined(USE_PTHREADS) -#include + #include namespace Stockfish { -static const size_t TH_STACK_SIZE = 8 * 1024 * 1024; + static const size_t TH_STACK_SIZE = 8 * 1024 * 1024; -template > -void* start_routine(void* ptr) -{ - P* p = reinterpret_cast(ptr); - (p->first->*(p->second))(); // Call member function pointer - delete p; - return nullptr; -} + template> void* start_routine(void* ptr) { + P* p = reinterpret_cast(ptr); + (p->first->*(p->second))(); // Call member function pointer + delete p; + return nullptr; + } -class NativeThread { + class NativeThread { - pthread_t thread; + pthread_t thread; -public: - template> - explicit NativeThread(void(T::*fun)(), T* obj) { - pthread_attr_t attr_storage, *attr = &attr_storage; - pthread_attr_init(attr); - pthread_attr_setstacksize(attr, TH_STACK_SIZE); - pthread_create(&thread, attr, start_routine, new P(obj, fun)); - } - void join() { pthread_join(thread, nullptr); } -}; + public: + template> + explicit NativeThread(void (T::*fun)(), T* obj) { + pthread_attr_t attr_storage, *attr = &attr_storage; + pthread_attr_init(attr); + pthread_attr_setstacksize(attr, TH_STACK_SIZE); + pthread_create(&thread, attr, start_routine, new P(obj, fun)); + } + void join() { pthread_join(thread, nullptr); } + }; -} // namespace Stockfish +} // namespace Stockfish -#else // Default case: use STL classes +#else // Default case: use STL classes namespace Stockfish { -using NativeThread = std::thread; + using NativeThread = std::thread; -} // namespace Stockfish +} // namespace Stockfish #endif -#endif // #ifndef THREAD_WIN32_OSX_H_INCLUDED +#endif // #ifndef THREAD_WIN32_OSX_H_INCLUDED diff --git a/src/timeman.cpp b/src/timeman.cpp index 5e57f8f98c5..720e0d52063 100644 --- a/src/timeman.cpp +++ b/src/timeman.cpp @@ -26,84 +26,80 @@ namespace Stockfish { -TimeManagement Time; // Our global time management object - - -/// TimeManagement::init() is called at the beginning of the search and calculates -/// the bounds of time allowed for the current game ply. We currently support: -// 1) x basetime (+ z increment) -// 2) x moves in y seconds (+ z increment) - -void TimeManagement::init(Search::LimitsType& limits, Color us, int ply) { - - // if we have no time, no need to initialize TM, except for the start time, - // which is used by movetime. - startTime = limits.startTime; - if (limits.time[us] == 0) - return; - - TimePoint moveOverhead = TimePoint(Options["Move Overhead"]); - TimePoint slowMover = TimePoint(Options["Slow Mover"]); - TimePoint npmsec = TimePoint(Options["nodestime"]); - - // optScale is a percentage of available time to use for the current move. - // maxScale is a multiplier applied to optimumTime. - double optScale, maxScale; - - // If we have to play in 'nodes as time' mode, then convert from time - // to nodes, and use resulting values in time management formulas. - // WARNING: to avoid time losses, the given npmsec (nodes per millisecond) - // must be much lower than the real engine speed. - if (npmsec) - { - if (!availableNodes) // Only once at game start - availableNodes = npmsec * limits.time[us]; // Time is in msec - - // Convert from milliseconds to nodes - limits.time[us] = TimePoint(availableNodes); - limits.inc[us] *= npmsec; - limits.npmsec = npmsec; - } - - // Maximum move horizon of 50 moves - int mtg = limits.movestogo ? std::min(limits.movestogo, 50) : 50; - - // Make sure timeLeft is > 0 since we may use it as a divisor - TimePoint timeLeft = std::max(TimePoint(1), - limits.time[us] + limits.inc[us] * (mtg - 1) - moveOverhead * (2 + mtg)); - - // Use extra time with larger increments - double optExtra = std::clamp(1.0 + 12.0 * limits.inc[us] / limits.time[us], 1.0, 1.12); - - // A user may scale time usage by setting UCI option "Slow Mover" - // Default is 100 and changing this value will probably lose elo. - timeLeft = slowMover * timeLeft / 100; - - // x basetime (+ z increment) - // If there is a healthy increment, timeLeft can exceed actual available - // game time for the current move, so also cap to 20% of available game time. - if (limits.movestogo == 0) - { - optScale = std::min(0.0120 + std::pow(ply + 3.0, 0.45) * 0.0039, - 0.2 * limits.time[us] / double(timeLeft)) - * optExtra; - maxScale = std::min(7.0, 4.0 + ply / 12.0); - } - - // x moves in y seconds (+ z increment) - else - { - optScale = std::min((0.88 + ply / 116.4) / mtg, - 0.88 * limits.time[us] / double(timeLeft)); - maxScale = std::min(6.3, 1.5 + 0.11 * mtg); - } - - // Never use more than 80% of the available time for this move - optimumTime = TimePoint(optScale * timeLeft); - maximumTime = TimePoint(std::min(0.8 * limits.time[us] - moveOverhead, maxScale * optimumTime)) - 10; - - if (Options["Ponder"]) - optimumTime += optimumTime / 4; -} - -} // namespace Stockfish + TimeManagement Time; // Our global time management object + + + /// TimeManagement::init() is called at the beginning of the search and calculates + /// the bounds of time allowed for the current game ply. We currently support: + // 1) x basetime (+ z increment) + // 2) x moves in y seconds (+ z increment) + + void TimeManagement::init(Search::LimitsType& limits, Color us, int ply) { + + // if we have no time, no need to initialize TM, except for the start time, + // which is used by movetime. + startTime = limits.startTime; + if (limits.time[us] == 0) return; + + TimePoint moveOverhead = TimePoint(Options["Move Overhead"]); + TimePoint slowMover = TimePoint(Options["Slow Mover"]); + TimePoint npmsec = TimePoint(Options["nodestime"]); + + // optScale is a percentage of available time to use for the current move. + // maxScale is a multiplier applied to optimumTime. + double optScale, maxScale; + + // If we have to play in 'nodes as time' mode, then convert from time + // to nodes, and use resulting values in time management formulas. + // WARNING: to avoid time losses, the given npmsec (nodes per millisecond) + // must be much lower than the real engine speed. + if (npmsec) { + if (!availableNodes) // Only once at game start + availableNodes = npmsec * limits.time[us]; // Time is in msec + + // Convert from milliseconds to nodes + limits.time[us] = TimePoint(availableNodes); + limits.inc[us] *= npmsec; + limits.npmsec = npmsec; + } + + // Maximum move horizon of 50 moves + int mtg = limits.movestogo ? std::min(limits.movestogo, 50) : 50; + + // Make sure timeLeft is > 0 since we may use it as a divisor + TimePoint timeLeft = std::max(TimePoint(1), limits.time[us] + limits.inc[us] * (mtg - 1) - + moveOverhead * (2 + mtg)); + + // Use extra time with larger increments + double optExtra = std::clamp(1.0 + 12.0 * limits.inc[us] / limits.time[us], 1.0, 1.12); + + // A user may scale time usage by setting UCI option "Slow Mover" + // Default is 100 and changing this value will probably lose elo. + timeLeft = slowMover * timeLeft / 100; + + // x basetime (+ z increment) + // If there is a healthy increment, timeLeft can exceed actual available + // game time for the current move, so also cap to 20% of available game time. + if (limits.movestogo == 0) { + optScale = std::min(0.0120 + std::pow(ply + 3.0, 0.45) * 0.0039, + 0.2 * limits.time[us] / double(timeLeft)) * + optExtra; + maxScale = std::min(7.0, 4.0 + ply / 12.0); + } + + // x moves in y seconds (+ z increment) + else { + optScale = + std::min((0.88 + ply / 116.4) / mtg, 0.88 * limits.time[us] / double(timeLeft)); + maxScale = std::min(6.3, 1.5 + 0.11 * mtg); + } + + // Never use more than 80% of the available time for this move + optimumTime = TimePoint(optScale * timeLeft); + maximumTime = + TimePoint(std::min(0.8 * limits.time[us] - moveOverhead, maxScale * optimumTime)) - 10; + + if (Options["Ponder"]) optimumTime += optimumTime / 4; + } + +} // namespace Stockfish diff --git a/src/timeman.h b/src/timeman.h index 9ad6bdcccf9..ad3383fe1fc 100644 --- a/src/timeman.h +++ b/src/timeman.h @@ -28,27 +28,28 @@ namespace Stockfish { -/// The TimeManagement class computes the optimal time to think depending on -/// the maximum available time, the game move number and other parameters. + /// The TimeManagement class computes the optimal time to think depending on + /// the maximum available time, the game move number and other parameters. -class TimeManagement { -public: - void init(Search::LimitsType& limits, Color us, int ply); - TimePoint optimum() const { return optimumTime; } - TimePoint maximum() const { return maximumTime; } - TimePoint elapsed() const { return Search::Limits.npmsec ? - TimePoint(Threads.nodes_searched()) : now() - startTime; } + class TimeManagement { + public: + void init(Search::LimitsType& limits, Color us, int ply); + TimePoint optimum() const { return optimumTime; } + TimePoint maximum() const { return maximumTime; } + TimePoint elapsed() const { + return Search::Limits.npmsec ? TimePoint(Threads.nodes_searched()) : now() - startTime; + } - int64_t availableNodes; // When in 'nodes as time' mode + int64_t availableNodes; // When in 'nodes as time' mode -private: - TimePoint startTime; - TimePoint optimumTime; - TimePoint maximumTime; -}; + private: + TimePoint startTime; + TimePoint optimumTime; + TimePoint maximumTime; + }; -extern TimeManagement Time; + extern TimeManagement Time; -} // namespace Stockfish +} // namespace Stockfish -#endif // #ifndef TIMEMAN_H_INCLUDED +#endif // #ifndef TIMEMAN_H_INCLUDED diff --git a/src/tt.cpp b/src/tt.cpp index 1582121fd6d..ccbff3a518d 100644 --- a/src/tt.cpp +++ b/src/tt.cpp @@ -31,135 +31,129 @@ namespace Stockfish { -TranspositionTable TT; // Our global transposition table + TranspositionTable TT; // Our global transposition table -/// TTEntry::save() populates the TTEntry with a new node's data, possibly -/// overwriting an old position. Update is not atomic and can be racy. + /// TTEntry::save() populates the TTEntry with a new node's data, possibly + /// overwriting an old position. Update is not atomic and can be racy. -void TTEntry::save(Key k, Value v, bool pv, Bound b, Depth d, Move m, Value ev) { + void TTEntry::save(Key k, Value v, bool pv, Bound b, Depth d, Move m, Value ev) { - // Preserve any existing move for the same position - if (m || (uint16_t)k != key16) - move16 = (uint16_t)m; + // Preserve any existing move for the same position + if (m || (uint16_t) k != key16) move16 = (uint16_t) m; - // Overwrite less valuable entries (cheapest checks first) - if ( b == BOUND_EXACT - || (uint16_t)k != key16 - || d - DEPTH_OFFSET + 2 * pv > depth8 - 4) - { - assert(d > DEPTH_OFFSET); - assert(d < 256 + DEPTH_OFFSET); + // Overwrite less valuable entries (cheapest checks first) + if (b == BOUND_EXACT || (uint16_t) k != key16 || d - DEPTH_OFFSET + 2 * pv > depth8 - 4) { + assert(d > DEPTH_OFFSET); + assert(d < 256 + DEPTH_OFFSET); - key16 = (uint16_t)k; - depth8 = (uint8_t)(d - DEPTH_OFFSET); - genBound8 = (uint8_t)(TT.generation8 | uint8_t(pv) << 2 | b); - value16 = (int16_t)v; - eval16 = (int16_t)ev; - } -} + key16 = (uint16_t) k; + depth8 = (uint8_t) (d - DEPTH_OFFSET); + genBound8 = (uint8_t) (TT.generation8 | uint8_t(pv) << 2 | b); + value16 = (int16_t) v; + eval16 = (int16_t) ev; + } + } -/// TranspositionTable::resize() sets the size of the transposition table, -/// measured in megabytes. Transposition table consists of a power of 2 number -/// of clusters and each cluster consists of ClusterSize number of TTEntry. + /// TranspositionTable::resize() sets the size of the transposition table, + /// measured in megabytes. Transposition table consists of a power of 2 number + /// of clusters and each cluster consists of ClusterSize number of TTEntry. -void TranspositionTable::resize(size_t mbSize) { + void TranspositionTable::resize(size_t mbSize) { - Threads.main()->wait_for_search_finished(); + Threads.main()->wait_for_search_finished(); - aligned_large_pages_free(table); + aligned_large_pages_free(table); - clusterCount = mbSize * 1024 * 1024 / sizeof(Cluster); + clusterCount = mbSize * 1024 * 1024 / sizeof(Cluster); - table = static_cast(aligned_large_pages_alloc(clusterCount * sizeof(Cluster))); - if (!table) - { - std::cerr << "Failed to allocate " << mbSize - << "MB for transposition table." << std::endl; - exit(EXIT_FAILURE); - } + table = static_cast(aligned_large_pages_alloc(clusterCount * sizeof(Cluster))); + if (!table) { + std::cerr << "Failed to allocate " << mbSize << "MB for transposition table." + << std::endl; + exit(EXIT_FAILURE); + } - clear(); -} + clear(); + } -/// TranspositionTable::clear() initializes the entire transposition table to zero, -// in a multi-threaded way. + /// TranspositionTable::clear() initializes the entire transposition table to zero, + // in a multi-threaded way. -void TranspositionTable::clear() { + void TranspositionTable::clear() { - std::vector threads; + std::vector threads; - for (size_t idx = 0; idx < size_t(Options["Threads"]); ++idx) - { - threads.emplace_back([this, idx]() { + for (size_t idx = 0; idx < size_t(Options["Threads"]); ++idx) { + threads.emplace_back([this, idx]() { + // Thread binding gives faster search on systems with a first-touch policy + if (Options["Threads"] > 8) WinProcGroup::bindThisThread(idx); - // Thread binding gives faster search on systems with a first-touch policy - if (Options["Threads"] > 8) - WinProcGroup::bindThisThread(idx); + // Each thread will zero its part of the hash table + const size_t stride = size_t(clusterCount / Options["Threads"]), + start = size_t(stride * idx), + len = idx != size_t(Options["Threads"]) - 1 ? stride : + clusterCount - start; - // Each thread will zero its part of the hash table - const size_t stride = size_t(clusterCount / Options["Threads"]), - start = size_t(stride * idx), - len = idx != size_t(Options["Threads"]) - 1 ? - stride : clusterCount - start; + std::memset(&table[start], 0, len * sizeof(Cluster)); + }); + } - std::memset(&table[start], 0, len * sizeof(Cluster)); - }); - } + for (std::thread& th : threads) th.join(); + } - for (std::thread& th : threads) - th.join(); -} + /// TranspositionTable::probe() looks up the current position in the transposition + /// table. It returns true and a pointer to the TTEntry if the position is found. + /// Otherwise, it returns false and a pointer to an empty or least valuable TTEntry + /// to be replaced later. The replace value of an entry is calculated as its depth + /// minus 8 times its relative age. TTEntry t1 is considered more valuable than + /// TTEntry t2 if its replace value is greater than that of t2. -/// TranspositionTable::probe() looks up the current position in the transposition -/// table. It returns true and a pointer to the TTEntry if the position is found. -/// Otherwise, it returns false and a pointer to an empty or least valuable TTEntry -/// to be replaced later. The replace value of an entry is calculated as its depth -/// minus 8 times its relative age. TTEntry t1 is considered more valuable than -/// TTEntry t2 if its replace value is greater than that of t2. + TTEntry* TranspositionTable::probe(const Key key, bool& found) const { -TTEntry* TranspositionTable::probe(const Key key, bool& found) const { + TTEntry* const tte = first_entry(key); + const uint16_t key16 = (uint16_t) key; // Use the low 16 bits as key inside the cluster - TTEntry* const tte = first_entry(key); - const uint16_t key16 = (uint16_t)key; // Use the low 16 bits as key inside the cluster + for (int i = 0; i < ClusterSize; ++i) + if (tte[i].key16 == key16 || !tte[i].depth8) { + tte[i].genBound8 = + uint8_t(generation8 | (tte[i].genBound8 & (GENERATION_DELTA - 1))); // Refresh - for (int i = 0; i < ClusterSize; ++i) - if (tte[i].key16 == key16 || !tte[i].depth8) - { - tte[i].genBound8 = uint8_t(generation8 | (tte[i].genBound8 & (GENERATION_DELTA - 1))); // Refresh + return found = (bool) tte[i].depth8, &tte[i]; + } - return found = (bool)tte[i].depth8, &tte[i]; - } + // Find an entry to be replaced according to the replacement strategy + TTEntry* replace = tte; + for (int i = 1; i < ClusterSize; ++i) + // Due to our packed storage format for generation and its cyclic + // nature we add GENERATION_CYCLE (256 is the modulus, plus what + // is needed to keep the unrelated lowest n bits from affecting + // the result) to calculate the entry age correctly even after + // generation8 overflows into the next cycle. + if (replace->depth8 - + ((GENERATION_CYCLE + generation8 - replace->genBound8) & GENERATION_MASK) > + tte[i].depth8 - + ((GENERATION_CYCLE + generation8 - tte[i].genBound8) & GENERATION_MASK)) + replace = &tte[i]; - // Find an entry to be replaced according to the replacement strategy - TTEntry* replace = tte; - for (int i = 1; i < ClusterSize; ++i) - // Due to our packed storage format for generation and its cyclic - // nature we add GENERATION_CYCLE (256 is the modulus, plus what - // is needed to keep the unrelated lowest n bits from affecting - // the result) to calculate the entry age correctly even after - // generation8 overflows into the next cycle. - if ( replace->depth8 - ((GENERATION_CYCLE + generation8 - replace->genBound8) & GENERATION_MASK) - > tte[i].depth8 - ((GENERATION_CYCLE + generation8 - tte[i].genBound8) & GENERATION_MASK)) - replace = &tte[i]; + return found = false, replace; + } - return found = false, replace; -} + /// TranspositionTable::hashfull() returns an approximation of the hashtable + /// occupation during a search. The hash is x permill full, as per UCI protocol. -/// TranspositionTable::hashfull() returns an approximation of the hashtable -/// occupation during a search. The hash is x permill full, as per UCI protocol. + int TranspositionTable::hashfull() const { -int TranspositionTable::hashfull() const { + int cnt = 0; + for (int i = 0; i < 1000; ++i) + for (int j = 0; j < ClusterSize; ++j) + cnt += table[i].entry[j].depth8 && + (table[i].entry[j].genBound8 & GENERATION_MASK) == generation8; - int cnt = 0; - for (int i = 0; i < 1000; ++i) - for (int j = 0; j < ClusterSize; ++j) - cnt += table[i].entry[j].depth8 && (table[i].entry[j].genBound8 & GENERATION_MASK) == generation8; + return cnt / ClusterSize; + } - return cnt / ClusterSize; -} - -} // namespace Stockfish +} // namespace Stockfish diff --git a/src/tt.h b/src/tt.h index df962faaa7b..fc2ddb8291b 100644 --- a/src/tt.h +++ b/src/tt.h @@ -27,84 +27,88 @@ namespace Stockfish { -/// TTEntry struct is the 10 bytes transposition table entry, defined as below: -/// -/// key 16 bit -/// depth 8 bit -/// generation 5 bit -/// pv node 1 bit -/// bound type 2 bit -/// move 16 bit -/// value 16 bit -/// eval value 16 bit - -struct TTEntry { - - Move move() const { return (Move )move16; } - Value value() const { return (Value)value16; } - Value eval() const { return (Value)eval16; } - Depth depth() const { return (Depth)depth8 + DEPTH_OFFSET; } - bool is_pv() const { return (bool)(genBound8 & 0x4); } - Bound bound() const { return (Bound)(genBound8 & 0x3); } - void save(Key k, Value v, bool pv, Bound b, Depth d, Move m, Value ev); - -private: - friend class TranspositionTable; - - uint16_t key16; - uint8_t depth8; - uint8_t genBound8; - uint16_t move16; - int16_t value16; - int16_t eval16; -}; - - -/// A TranspositionTable is an array of Cluster, of size clusterCount. Each -/// cluster consists of ClusterSize number of TTEntry. Each non-empty TTEntry -/// contains information on exactly one position. The size of a Cluster should -/// divide the size of a cache line for best performance, as the cacheline is -/// prefetched when possible. - -class TranspositionTable { - - static constexpr int ClusterSize = 3; - - struct Cluster { - TTEntry entry[ClusterSize]; - char padding[2]; // Pad to 32 bytes - }; - - static_assert(sizeof(Cluster) == 32, "Unexpected Cluster size"); - - // Constants used to refresh the hash table periodically - static constexpr unsigned GENERATION_BITS = 3; // nb of bits reserved for other things - static constexpr int GENERATION_DELTA = (1 << GENERATION_BITS); // increment for generation field - static constexpr int GENERATION_CYCLE = 255 + (1 << GENERATION_BITS); // cycle length - static constexpr int GENERATION_MASK = (0xFF << GENERATION_BITS) & 0xFF; // mask to pull out generation number - -public: - ~TranspositionTable() { aligned_large_pages_free(table); } - void new_search() { generation8 += GENERATION_DELTA; } // Lower bits are used for other things - TTEntry* probe(const Key key, bool& found) const; - int hashfull() const; - void resize(size_t mbSize); - void clear(); - - TTEntry* first_entry(const Key key) const { - return &table[mul_hi64(key, clusterCount)].entry[0]; - } - -private: - friend struct TTEntry; - - size_t clusterCount; - Cluster* table; - uint8_t generation8; // Size must be not bigger than TTEntry::genBound8 -}; - -extern TranspositionTable TT; - -} // namespace Stockfish - -#endif // #ifndef TT_H_INCLUDED + /// TTEntry struct is the 10 bytes transposition table entry, defined as below: + /// + /// key 16 bit + /// depth 8 bit + /// generation 5 bit + /// pv node 1 bit + /// bound type 2 bit + /// move 16 bit + /// value 16 bit + /// eval value 16 bit + + struct TTEntry { + + Move move() const { return (Move) move16; } + Value value() const { return (Value) value16; } + Value eval() const { return (Value) eval16; } + Depth depth() const { return (Depth) depth8 + DEPTH_OFFSET; } + bool is_pv() const { return (bool) (genBound8 & 0x4); } + Bound bound() const { return (Bound) (genBound8 & 0x3); } + void save(Key k, Value v, bool pv, Bound b, Depth d, Move m, Value ev); + + private: + friend class TranspositionTable; + + uint16_t key16; + uint8_t depth8; + uint8_t genBound8; + uint16_t move16; + int16_t value16; + int16_t eval16; + }; + + + /// A TranspositionTable is an array of Cluster, of size clusterCount. Each + /// cluster consists of ClusterSize number of TTEntry. Each non-empty TTEntry + /// contains information on exactly one position. The size of a Cluster should + /// divide the size of a cache line for best performance, as the cacheline is + /// prefetched when possible. + + class TranspositionTable { + + static constexpr int ClusterSize = 3; + + struct Cluster { + TTEntry entry[ClusterSize]; + char padding[2]; // Pad to 32 bytes + }; + + static_assert(sizeof(Cluster) == 32, "Unexpected Cluster size"); + + // Constants used to refresh the hash table periodically + static constexpr unsigned GENERATION_BITS = 3; // nb of bits reserved for other things + static constexpr int GENERATION_DELTA = + (1 << GENERATION_BITS); // increment for generation field + static constexpr int GENERATION_CYCLE = 255 + (1 << GENERATION_BITS); // cycle length + static constexpr int GENERATION_MASK = + (0xFF << GENERATION_BITS) & 0xFF; // mask to pull out generation number + + public: + ~TranspositionTable() { aligned_large_pages_free(table); } + void new_search() { + generation8 += GENERATION_DELTA; + } // Lower bits are used for other things + TTEntry* probe(const Key key, bool& found) const; + int hashfull() const; + void resize(size_t mbSize); + void clear(); + + TTEntry* first_entry(const Key key) const { + return &table[mul_hi64(key, clusterCount)].entry[0]; + } + + private: + friend struct TTEntry; + + size_t clusterCount; + Cluster* table; + uint8_t generation8; // Size must be not bigger than TTEntry::genBound8 + }; + + extern TranspositionTable TT; + +} // namespace Stockfish + +#endif // #ifndef TT_H_INCLUDED diff --git a/src/tune.cpp b/src/tune.cpp index 97baeb784e9..229916fa7cd 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -27,82 +27,73 @@ #include "uci.h" namespace Stockfish { -enum Value : int; + enum Value : int; } using std::string; namespace Stockfish { -bool Tune::update_on_last; -const UCI::Option* LastOption = nullptr; -static std::map TuneResults; + bool Tune::update_on_last; + const UCI::Option* LastOption = nullptr; + static std::map TuneResults; -string Tune::next(string& names, bool pop) { + string Tune::next(string& names, bool pop) { - string name; + string name; - do { - string token = names.substr(0, names.find(',')); + do { + string token = names.substr(0, names.find(',')); - if (pop) - names.erase(0, token.size() + 1); + if (pop) names.erase(0, token.size() + 1); - std::stringstream ws(token); - name += (ws >> token, token); // Remove trailing whitespace + std::stringstream ws(token); + name += (ws >> token, token); // Remove trailing whitespace - } while ( std::count(name.begin(), name.end(), '(') - - std::count(name.begin(), name.end(), ')')); + } while (std::count(name.begin(), name.end(), '(') - + std::count(name.begin(), name.end(), ')')); - return name; -} + return name; + } -static void on_tune(const UCI::Option& o) { + static void on_tune(const UCI::Option& o) { - if (!Tune::update_on_last || LastOption == &o) - Tune::read_options(); -} + if (!Tune::update_on_last || LastOption == &o) Tune::read_options(); + } -static void make_option(const string& n, int v, const SetRange& r) { + static void make_option(const string& n, int v, const SetRange& r) { - // Do not generate option when there is nothing to tune (ie. min = max) - if (r(v).first == r(v).second) - return; + // Do not generate option when there is nothing to tune (ie. min = max) + if (r(v).first == r(v).second) return; - if (TuneResults.count(n)) - v = TuneResults[n]; + if (TuneResults.count(n)) v = TuneResults[n]; - Options[n] << UCI::Option(v, r(v).first, r(v).second, on_tune); - LastOption = &Options[n]; + Options[n] << UCI::Option(v, r(v).first, r(v).second, on_tune); + LastOption = &Options[n]; - // Print formatted parameters, ready to be copy-pasted in Fishtest - std::cout << n << "," - << v << "," - << r(v).first << "," << r(v).second << "," - << (r(v).second - r(v).first) / 20.0 << "," - << "0.0020" - << std::endl; -} + // Print formatted parameters, ready to be copy-pasted in Fishtest + std::cout << n << "," << v << "," << r(v).first << "," << r(v).second << "," + << (r(v).second - r(v).first) / 20.0 << "," + << "0.0020" << std::endl; + } -template<> void Tune::Entry::init_option() { make_option(name, value, range); } + template<> void Tune::Entry::init_option() { make_option(name, value, range); } -template<> void Tune::Entry::read_option() { - if (Options.count(name)) - value = int(Options[name]); -} + template<> void Tune::Entry::read_option() { + if (Options.count(name)) value = int(Options[name]); + } -template<> void Tune::Entry::init_option() { make_option(name, value, range); } + template<> void Tune::Entry::init_option() { make_option(name, value, range); } -template<> void Tune::Entry::read_option() { - if (Options.count(name)) - value = Value(int(Options[name])); -} + template<> void Tune::Entry::read_option() { + if (Options.count(name)) value = Value(int(Options[name])); + } -// Instead of a variable here we have a PostUpdate function: just call it -template<> void Tune::Entry::init_option() {} -template<> void Tune::Entry::read_option() { value(); } + // Instead of a variable here we have a PostUpdate function: just call it + template<> void Tune::Entry::init_option() {} + template<> void Tune::Entry::read_option() { value(); } -} // namespace Stockfish +} // namespace Stockfish // Init options with tuning session results instead of default values. Useful to @@ -117,9 +108,7 @@ template<> void Tune::Entry::read_option() { value(); } namespace Stockfish { -void Tune::read_results() { - - /* ...insert your values here... */ -} + void Tune::read_results() { /* ...insert your values here... */ + } -} // namespace Stockfish +} // namespace Stockfish diff --git a/src/tune.h b/src/tune.h index 3e94f7efc6c..fd210e40141 100644 --- a/src/tune.h +++ b/src/tune.h @@ -27,138 +27,143 @@ #include namespace Stockfish { -enum Value : int; + enum Value : int; -using Range = std::pair; // Option's min-max values -using RangeFun = Range (int); + using Range = std::pair; // Option's min-max values + using RangeFun = Range(int); -// Default Range function, to calculate Option's min-max values -inline Range default_range(int v) { - return v > 0 ? Range(0, 2 * v) : Range(2 * v, 0); -} + // Default Range function, to calculate Option's min-max values + inline Range default_range(int v) { return v > 0 ? Range(0, 2 * v) : Range(2 * v, 0); } -struct SetRange { - explicit SetRange(RangeFun f) : fun(f) {} - SetRange(int min, int max) : fun(nullptr), range(min, max) {} - Range operator()(int v) const { return fun ? fun(v) : range; } + struct SetRange { + explicit SetRange(RangeFun f) : fun(f) {} + SetRange(int min, int max) : fun(nullptr), range(min, max) {} + Range operator()(int v) const { return fun ? fun(v) : range; } - RangeFun* fun; - Range range; -}; + RangeFun* fun; + Range range; + }; #define SetDefaultRange SetRange(default_range) -/// Tune class implements the 'magic' code that makes the setup of a fishtest -/// tuning session as easy as it can be. Mainly you have just to remove const -/// qualifiers from the variables you want to tune and flag them for tuning, so -/// if you have: -/// -/// const Value myValue[][2] = { { V(100), V(20) }, { V(7), V(78) } }; -/// -/// If you have a my_post_update() function to run after values have been updated, -/// and a my_range() function to set custom Option's min-max values, then you just -/// remove the 'const' qualifiers and write somewhere below in the file: -/// -/// TUNE(SetRange(my_range), myValue, my_post_update); -/// -/// You can also set the range directly, and restore the default at the end -/// -/// TUNE(SetRange(-100, 100), myValue, SetDefaultRange); -/// -/// In case update function is slow and you have many parameters, you can add: -/// -/// UPDATE_ON_LAST(); -/// -/// And the values update, including post update function call, will be done only -/// once, after the engine receives the last UCI option, that is the one defined -/// and created as the last one, so the GUI should send the options in the same -/// order in which have been defined. - -class Tune { - - using PostUpdate = void (); // Post-update function - - Tune() { read_results(); } - Tune(const Tune&) = delete; - void operator=(const Tune&) = delete; - void read_results(); - - static Tune& instance() { static Tune t; return t; } // Singleton - - // Use polymorphism to accommodate Entry of different types in the same vector - struct EntryBase { - virtual ~EntryBase() = default; - virtual void init_option() = 0; - virtual void read_option() = 0; - }; - - template - struct Entry : public EntryBase { - - static_assert(!std::is_const::value, "Parameter cannot be const!"); - - static_assert( std::is_same::value - || std::is_same::value - || std::is_same::value, "Parameter type not supported!"); - - Entry(const std::string& n, T& v, const SetRange& r) : name(n), value(v), range(r) {} - void operator=(const Entry&) = delete; // Because 'value' is a reference - void init_option() override; - void read_option() override; - - std::string name; - T& value; - SetRange range; - }; - - // Our facility to fill the container, each Entry corresponds to a parameter - // to tune. We use variadic templates to deal with an unspecified number of - // entries, each one of a possible different type. - static std::string next(std::string& names, bool pop = true); - - int add(const SetRange&, std::string&&) { return 0; } - - template - int add(const SetRange& range, std::string&& names, T& value, Args&&... args) { - list.push_back(std::unique_ptr(new Entry(next(names), value, range))); - return add(range, std::move(names), args...); - } - - // Template specialization for arrays: recursively handle multi-dimensional arrays - template - int add(const SetRange& range, std::string&& names, T (&value)[N], Args&&... args) { - for (size_t i = 0; i < N; i++) - add(range, next(names, i == N - 1) + "[" + std::to_string(i) + "]", value[i]); - return add(range, std::move(names), args...); - } - - // Template specialization for SetRange - template - int add(const SetRange&, std::string&& names, SetRange& value, Args&&... args) { - return add(value, (next(names), std::move(names)), args...); - } - - std::vector> list; - -public: - template - static int add(const std::string& names, Args&&... args) { - return instance().add(SetDefaultRange, names.substr(1, names.size() - 2), args...); // Remove trailing parenthesis - } - static void init() { for (auto& e : instance().list) e->init_option(); read_options(); } // Deferred, due to UCI::Options access - static void read_options() { for (auto& e : instance().list) e->read_option(); } - static bool update_on_last; -}; + /// Tune class implements the 'magic' code that makes the setup of a fishtest + /// tuning session as easy as it can be. Mainly you have just to remove const + /// qualifiers from the variables you want to tune and flag them for tuning, so + /// if you have: + /// + /// const Value myValue[][2] = { { V(100), V(20) }, { V(7), V(78) } }; + /// + /// If you have a my_post_update() function to run after values have been updated, + /// and a my_range() function to set custom Option's min-max values, then you just + /// remove the 'const' qualifiers and write somewhere below in the file: + /// + /// TUNE(SetRange(my_range), myValue, my_post_update); + /// + /// You can also set the range directly, and restore the default at the end + /// + /// TUNE(SetRange(-100, 100), myValue, SetDefaultRange); + /// + /// In case update function is slow and you have many parameters, you can add: + /// + /// UPDATE_ON_LAST(); + /// + /// And the values update, including post update function call, will be done only + /// once, after the engine receives the last UCI option, that is the one defined + /// and created as the last one, so the GUI should send the options in the same + /// order in which have been defined. + + class Tune { + + using PostUpdate = void(); // Post-update function + + Tune() { read_results(); } + Tune(const Tune&) = delete; + void operator=(const Tune&) = delete; + void read_results(); + + static Tune& instance() { + static Tune t; + return t; + } // Singleton + + // Use polymorphism to accommodate Entry of different types in the same vector + struct EntryBase { + virtual ~EntryBase() = default; + virtual void init_option() = 0; + virtual void read_option() = 0; + }; + + template struct Entry: public EntryBase { + + static_assert(!std::is_const::value, "Parameter cannot be const!"); + + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value, + "Parameter type not supported!"); + + Entry(const std::string& n, T& v, const SetRange& r) : name(n), value(v), range(r) {} + void operator=(const Entry&) = delete; // Because 'value' is a reference + void init_option() override; + void read_option() override; + + std::string name; + T& value; + SetRange range; + }; + + // Our facility to fill the container, each Entry corresponds to a parameter + // to tune. We use variadic templates to deal with an unspecified number of + // entries, each one of a possible different type. + static std::string next(std::string& names, bool pop = true); + + int add(const SetRange&, std::string&&) { return 0; } + + template + int add(const SetRange& range, std::string&& names, T& value, Args&&... args) { + list.push_back(std::unique_ptr(new Entry(next(names), value, range))); + return add(range, std::move(names), args...); + } + + // Template specialization for arrays: recursively handle multi-dimensional arrays + template + int add(const SetRange& range, std::string&& names, T (&value)[N], Args&&... args) { + for (size_t i = 0; i < N; i++) + add(range, next(names, i == N - 1) + "[" + std::to_string(i) + "]", value[i]); + return add(range, std::move(names), args...); + } + + // Template specialization for SetRange + template + int add(const SetRange&, std::string&& names, SetRange& value, Args&&... args) { + return add(value, (next(names), std::move(names)), args...); + } + + std::vector> list; + + public: + template static int add(const std::string& names, Args&&... args) { + return instance().add(SetDefaultRange, names.substr(1, names.size() - 2), + args...); // Remove trailing parenthesis + } + static void init() { + for (auto& e : instance().list) e->init_option(); + read_options(); + } // Deferred, due to UCI::Options access + static void read_options() { + for (auto& e : instance().list) e->read_option(); + } + static bool update_on_last; + }; // Some macro magic :-) we define a dummy int variable that compiler initializes calling Tune::add() #define STRINGIFY(x) #x -#define UNIQUE2(x, y) x ## y -#define UNIQUE(x, y) UNIQUE2(x, y) // Two indirection levels to expand __LINE__ +#define UNIQUE2(x, y) x##y +#define UNIQUE(x, y) UNIQUE2(x, y) // Two indirection levels to expand __LINE__ #define TUNE(...) int UNIQUE(p, __LINE__) = Tune::add(STRINGIFY((__VA_ARGS__)), __VA_ARGS__) #define UPDATE_ON_LAST() bool UNIQUE(p, __LINE__) = Tune::update_on_last = true -} // namespace Stockfish +} // namespace Stockfish -#endif // #ifndef TUNE_H_INCLUDED +#endif // #ifndef TUNE_H_INCLUDED diff --git a/src/types.h b/src/types.h index f81d30fe032..7016e4f8b0f 100644 --- a/src/types.h +++ b/src/types.h @@ -17,7 +17,7 @@ */ #ifndef TYPES_H_INCLUDED -#define TYPES_H_INCLUDED + #define TYPES_H_INCLUDED /// When compiling with provided Makefile (e.g. for Linux and OSX), configuration /// is done automatically. To get started type 'make help'. @@ -36,15 +36,15 @@ /// -DUSE_PEXT | Add runtime support for use of pext asm-instruction. Works /// | only in 64-bit mode and requires hardware with pext support. -#include -#include + #include + #include -#if defined(_MSC_VER) -// Disable some silly and noisy warning from MSVC compiler -#pragma warning(disable: 4127) // Conditional expression is constant -#pragma warning(disable: 4146) // Unary minus operator applied to unsigned type -#pragma warning(disable: 4800) // Forcing value to bool 'true' or 'false' -#endif + #if defined(_MSC_VER) + // Disable some silly and noisy warning from MSVC compiler + #pragma warning(disable: 4127) // Conditional expression is constant + #pragma warning(disable: 4146) // Unary minus operator applied to unsigned type + #pragma warning(disable: 4800) // Forcing value to bool 'true' or 'false' + #endif /// Predefined macros hell: /// @@ -55,364 +55,422 @@ /// _WIN32 Building on Windows (any) /// _WIN64 Building on Windows 64 bit -#if defined(__GNUC__ ) && (__GNUC__ < 9 || (__GNUC__ == 9 && __GNUC_MINOR__ <= 2)) && defined(_WIN32) && !defined(__clang__) -#define ALIGNAS_ON_STACK_VARIABLES_BROKEN -#endif + #if defined(__GNUC__) && (__GNUC__ < 9 || (__GNUC__ == 9 && __GNUC_MINOR__ <= 2)) && \ + defined(_WIN32) && !defined(__clang__) + #define ALIGNAS_ON_STACK_VARIABLES_BROKEN + #endif -#define ASSERT_ALIGNED(ptr, alignment) assert(reinterpret_cast(ptr) % alignment == 0) + #define ASSERT_ALIGNED(ptr, alignment) assert(reinterpret_cast(ptr) % alignment == 0) -#if defined(_WIN64) && defined(_MSC_VER) // No Makefile used -# include // Microsoft header for _BitScanForward64() -# define IS_64BIT -#endif + #if defined(_WIN64) && defined(_MSC_VER) // No Makefile used + #include // Microsoft header for _BitScanForward64() + #define IS_64BIT + #endif -#if defined(USE_POPCNT) && defined(_MSC_VER) -# include // Microsoft header for _mm_popcnt_u64() -#endif + #if defined(USE_POPCNT) && defined(_MSC_VER) + #include // Microsoft header for _mm_popcnt_u64() + #endif -#if !defined(NO_PREFETCH) && defined(_MSC_VER) -# include // Microsoft header for _mm_prefetch() -#endif + #if !defined(NO_PREFETCH) && defined(_MSC_VER) + #include // Microsoft header for _mm_prefetch() + #endif -#if defined(USE_PEXT) -# include // Header for _pext_u64() intrinsic -# define pext(b, m) _pext_u64(b, m) -#else -# define pext(b, m) 0 -#endif + #if defined(USE_PEXT) + #include // Header for _pext_u64() intrinsic + #define pext(b, m) _pext_u64(b, m) + #else + #define pext(b, m) 0 + #endif namespace Stockfish { -#ifdef USE_POPCNT -constexpr bool HasPopCnt = true; -#else -constexpr bool HasPopCnt = false; -#endif + #ifdef USE_POPCNT + constexpr bool HasPopCnt = true; + #else + constexpr bool HasPopCnt = false; + #endif + + #ifdef USE_PEXT + constexpr bool HasPext = true; + #else + constexpr bool HasPext = false; + #endif + + #ifdef IS_64BIT + constexpr bool Is64Bit = true; + #else + constexpr bool Is64Bit = false; + #endif + + using Key = uint64_t; + using Bitboard = uint64_t; + + constexpr int MAX_MOVES = 256; + constexpr int MAX_PLY = 246; + + /// A move needs 16 bits to be stored + /// + /// bit 0- 5: destination square (from 0 to 63) + /// bit 6-11: origin square (from 0 to 63) + /// bit 12-13: promotion piece type - 2 (from KNIGHT-2 to QUEEN-2) + /// bit 14-15: special move flag: promotion (1), en passant (2), castling (3) + /// NOTE: en passant bit is set only when a pawn can be captured + /// + /// Special cases are MOVE_NONE and MOVE_NULL. We can sneak these in because in + /// any normal move destination square is always different from origin square + /// while MOVE_NONE and MOVE_NULL have the same origin and destination square. + + enum Move : int { + MOVE_NONE, + MOVE_NULL = 65 + }; + + enum MoveType { + NORMAL, + PROMOTION = 1 << 14, + EN_PASSANT = 2 << 14, + CASTLING = 3 << 14 + }; + + enum Color { + WHITE, + BLACK, + COLOR_NB = 2 + }; + + enum CastlingRights { + NO_CASTLING, + WHITE_OO, + WHITE_OOO = WHITE_OO << 1, + BLACK_OO = WHITE_OO << 2, + BLACK_OOO = WHITE_OO << 3, + + KING_SIDE = WHITE_OO | BLACK_OO, + QUEEN_SIDE = WHITE_OOO | BLACK_OOO, + WHITE_CASTLING = WHITE_OO | WHITE_OOO, + BLACK_CASTLING = BLACK_OO | BLACK_OOO, + ANY_CASTLING = WHITE_CASTLING | BLACK_CASTLING, + + CASTLING_RIGHT_NB = 16 + }; + + enum Bound { + BOUND_NONE, + BOUND_UPPER, + BOUND_LOWER, + BOUND_EXACT = BOUND_UPPER | BOUND_LOWER + }; + + enum Value : int { + VALUE_ZERO = 0, + VALUE_DRAW = 0, + VALUE_KNOWN_WIN = 10000, + VALUE_MATE = 32000, + VALUE_INFINITE = 32001, + VALUE_NONE = 32002, + + VALUE_TB_WIN_IN_MAX_PLY = VALUE_MATE - 2 * MAX_PLY, + VALUE_TB_LOSS_IN_MAX_PLY = -VALUE_TB_WIN_IN_MAX_PLY, + VALUE_MATE_IN_MAX_PLY = VALUE_MATE - MAX_PLY, + VALUE_MATED_IN_MAX_PLY = -VALUE_MATE_IN_MAX_PLY, + + // In the code, we make the assumption that these values + // are such that non_pawn_material() can be used to uniquely + // identify the material on the board. + PawnValue = 208, + KnightValue = 781, + BishopValue = 825, + RookValue = 1276, + QueenValue = 2538, + }; + + enum PieceType { + NO_PIECE_TYPE, + PAWN, + KNIGHT, + BISHOP, + ROOK, + QUEEN, + KING, + ALL_PIECES = 0, + PIECE_TYPE_NB = 8 + }; + + enum Piece { + NO_PIECE, + W_PAWN = PAWN, + W_KNIGHT, + W_BISHOP, + W_ROOK, + W_QUEEN, + W_KING, + B_PAWN = PAWN + 8, + B_KNIGHT, + B_BISHOP, + B_ROOK, + B_QUEEN, + B_KING, + PIECE_NB = 16 + }; + + constexpr Value PieceValue[PIECE_NB] = {VALUE_ZERO, PawnValue, KnightValue, BishopValue, + RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO, + VALUE_ZERO, PawnValue, KnightValue, BishopValue, + RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO}; + + using Depth = int; + + enum : int { + DEPTH_QS_CHECKS = 0, + DEPTH_QS_NO_CHECKS = -1, + DEPTH_QS_RECAPTURES = -5, + + DEPTH_NONE = -6, + + DEPTH_OFFSET = -7 // value used only for TT entry occupancy check + }; + + enum Square : int { + SQ_A1, + SQ_B1, + SQ_C1, + SQ_D1, + SQ_E1, + SQ_F1, + SQ_G1, + SQ_H1, + SQ_A2, + SQ_B2, + SQ_C2, + SQ_D2, + SQ_E2, + SQ_F2, + SQ_G2, + SQ_H2, + SQ_A3, + SQ_B3, + SQ_C3, + SQ_D3, + SQ_E3, + SQ_F3, + SQ_G3, + SQ_H3, + SQ_A4, + SQ_B4, + SQ_C4, + SQ_D4, + SQ_E4, + SQ_F4, + SQ_G4, + SQ_H4, + SQ_A5, + SQ_B5, + SQ_C5, + SQ_D5, + SQ_E5, + SQ_F5, + SQ_G5, + SQ_H5, + SQ_A6, + SQ_B6, + SQ_C6, + SQ_D6, + SQ_E6, + SQ_F6, + SQ_G6, + SQ_H6, + SQ_A7, + SQ_B7, + SQ_C7, + SQ_D7, + SQ_E7, + SQ_F7, + SQ_G7, + SQ_H7, + SQ_A8, + SQ_B8, + SQ_C8, + SQ_D8, + SQ_E8, + SQ_F8, + SQ_G8, + SQ_H8, + SQ_NONE, + + SQUARE_ZERO = 0, + SQUARE_NB = 64 + }; + + enum Direction : int { + NORTH = 8, + EAST = 1, + SOUTH = -NORTH, + WEST = -EAST, + + NORTH_EAST = NORTH + EAST, + SOUTH_EAST = SOUTH + EAST, + SOUTH_WEST = SOUTH + WEST, + NORTH_WEST = NORTH + WEST + }; + + enum File : int { + FILE_A, + FILE_B, + FILE_C, + FILE_D, + FILE_E, + FILE_F, + FILE_G, + FILE_H, + FILE_NB + }; + + enum Rank : int { + RANK_1, + RANK_2, + RANK_3, + RANK_4, + RANK_5, + RANK_6, + RANK_7, + RANK_8, + RANK_NB + }; + + // Keep track of what a move changes on the board (used by NNUE) + struct DirtyPiece { + + // Number of changed pieces + int dirty_num; + + // Max 3 pieces can change in one move. A promotion with capture moves + // both the pawn and the captured piece to SQ_NONE and the piece promoted + // to from SQ_NONE to the capture square. + Piece piece[3]; + + // From and to squares, which may be SQ_NONE + Square from[3]; + Square to[3]; + }; + + #define ENABLE_BASE_OPERATORS_ON(T) \ + constexpr T operator+(T d1, int d2) { return T(int(d1) + d2); } \ + constexpr T operator-(T d1, int d2) { return T(int(d1) - d2); } \ + constexpr T operator-(T d) { return T(-int(d)); } \ + inline T& operator+=(T& d1, int d2) { return d1 = d1 + d2; } \ + inline T& operator-=(T& d1, int d2) { return d1 = d1 - d2; } + + #define ENABLE_INCR_OPERATORS_ON(T) \ + inline T& operator++(T& d) { return d = T(int(d) + 1); } \ + inline T& operator--(T& d) { return d = T(int(d) - 1); } + + #define ENABLE_FULL_OPERATORS_ON(T) \ + ENABLE_BASE_OPERATORS_ON(T) \ + constexpr T operator*(int i, T d) { return T(i * int(d)); } \ + constexpr T operator*(T d, int i) { return T(int(d) * i); } \ + constexpr T operator/(T d, int i) { return T(int(d) / i); } \ + constexpr int operator/(T d1, T d2) { return int(d1) / int(d2); } \ + inline T& operator*=(T& d, int i) { return d = T(int(d) * i); } \ + inline T& operator/=(T& d, int i) { return d = T(int(d) / i); } + + ENABLE_FULL_OPERATORS_ON(Value) + ENABLE_FULL_OPERATORS_ON(Direction) + + ENABLE_INCR_OPERATORS_ON(PieceType) + ENABLE_INCR_OPERATORS_ON(Square) + ENABLE_INCR_OPERATORS_ON(File) + ENABLE_INCR_OPERATORS_ON(Rank) + + #undef ENABLE_FULL_OPERATORS_ON + #undef ENABLE_INCR_OPERATORS_ON + #undef ENABLE_BASE_OPERATORS_ON + + /// Additional operators to add a Direction to a Square + constexpr Square operator+(Square s, Direction d) { return Square(int(s) + int(d)); } + constexpr Square operator-(Square s, Direction d) { return Square(int(s) - int(d)); } + inline Square& operator+=(Square& s, Direction d) { return s = s + d; } + inline Square& operator-=(Square& s, Direction d) { return s = s - d; } -#ifdef USE_PEXT -constexpr bool HasPext = true; -#else -constexpr bool HasPext = false; -#endif + constexpr Color operator~(Color c) { + return Color(c ^ BLACK); // Toggle color + } -#ifdef IS_64BIT -constexpr bool Is64Bit = true; -#else -constexpr bool Is64Bit = false; -#endif + constexpr Square flip_rank(Square s) { // Swap A1 <-> A8 + return Square(s ^ SQ_A8); + } -using Key = uint64_t; -using Bitboard = uint64_t; + constexpr Square flip_file(Square s) { // Swap A1 <-> H1 + return Square(s ^ SQ_H1); + } -constexpr int MAX_MOVES = 256; -constexpr int MAX_PLY = 246; + constexpr Piece operator~(Piece pc) { + return Piece(pc ^ 8); // Swap color of piece B_KNIGHT <-> W_KNIGHT + } -/// A move needs 16 bits to be stored -/// -/// bit 0- 5: destination square (from 0 to 63) -/// bit 6-11: origin square (from 0 to 63) -/// bit 12-13: promotion piece type - 2 (from KNIGHT-2 to QUEEN-2) -/// bit 14-15: special move flag: promotion (1), en passant (2), castling (3) -/// NOTE: en passant bit is set only when a pawn can be captured -/// -/// Special cases are MOVE_NONE and MOVE_NULL. We can sneak these in because in -/// any normal move destination square is always different from origin square -/// while MOVE_NONE and MOVE_NULL have the same origin and destination square. - -enum Move : int { - MOVE_NONE, - MOVE_NULL = 65 -}; - -enum MoveType { - NORMAL, - PROMOTION = 1 << 14, - EN_PASSANT = 2 << 14, - CASTLING = 3 << 14 -}; - -enum Color { - WHITE, BLACK, COLOR_NB = 2 -}; - -enum CastlingRights { - NO_CASTLING, - WHITE_OO, - WHITE_OOO = WHITE_OO << 1, - BLACK_OO = WHITE_OO << 2, - BLACK_OOO = WHITE_OO << 3, - - KING_SIDE = WHITE_OO | BLACK_OO, - QUEEN_SIDE = WHITE_OOO | BLACK_OOO, - WHITE_CASTLING = WHITE_OO | WHITE_OOO, - BLACK_CASTLING = BLACK_OO | BLACK_OOO, - ANY_CASTLING = WHITE_CASTLING | BLACK_CASTLING, - - CASTLING_RIGHT_NB = 16 -}; - -enum Bound { - BOUND_NONE, - BOUND_UPPER, - BOUND_LOWER, - BOUND_EXACT = BOUND_UPPER | BOUND_LOWER -}; - -enum Value : int { - VALUE_ZERO = 0, - VALUE_DRAW = 0, - VALUE_KNOWN_WIN = 10000, - VALUE_MATE = 32000, - VALUE_INFINITE = 32001, - VALUE_NONE = 32002, - - VALUE_TB_WIN_IN_MAX_PLY = VALUE_MATE - 2 * MAX_PLY, - VALUE_TB_LOSS_IN_MAX_PLY = -VALUE_TB_WIN_IN_MAX_PLY, - VALUE_MATE_IN_MAX_PLY = VALUE_MATE - MAX_PLY, - VALUE_MATED_IN_MAX_PLY = -VALUE_MATE_IN_MAX_PLY, - - // In the code, we make the assumption that these values - // are such that non_pawn_material() can be used to uniquely - // identify the material on the board. - PawnValue = 208, - KnightValue = 781, - BishopValue = 825, - RookValue = 1276, - QueenValue = 2538, -}; - -enum PieceType { - NO_PIECE_TYPE, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING, - ALL_PIECES = 0, - PIECE_TYPE_NB = 8 -}; - -enum Piece { - NO_PIECE, - W_PAWN = PAWN, W_KNIGHT, W_BISHOP, W_ROOK, W_QUEEN, W_KING, - B_PAWN = PAWN + 8, B_KNIGHT, B_BISHOP, B_ROOK, B_QUEEN, B_KING, - PIECE_NB = 16 -}; - -constexpr Value PieceValue[PIECE_NB] = { VALUE_ZERO, PawnValue, KnightValue, BishopValue, RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO, - VALUE_ZERO, PawnValue, KnightValue, BishopValue, RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO }; - -using Depth = int; - -enum : int { - DEPTH_QS_CHECKS = 0, - DEPTH_QS_NO_CHECKS = -1, - DEPTH_QS_RECAPTURES = -5, - - DEPTH_NONE = -6, - - DEPTH_OFFSET = -7 // value used only for TT entry occupancy check -}; - -enum Square : int { - SQ_A1, SQ_B1, SQ_C1, SQ_D1, SQ_E1, SQ_F1, SQ_G1, SQ_H1, - SQ_A2, SQ_B2, SQ_C2, SQ_D2, SQ_E2, SQ_F2, SQ_G2, SQ_H2, - SQ_A3, SQ_B3, SQ_C3, SQ_D3, SQ_E3, SQ_F3, SQ_G3, SQ_H3, - SQ_A4, SQ_B4, SQ_C4, SQ_D4, SQ_E4, SQ_F4, SQ_G4, SQ_H4, - SQ_A5, SQ_B5, SQ_C5, SQ_D5, SQ_E5, SQ_F5, SQ_G5, SQ_H5, - SQ_A6, SQ_B6, SQ_C6, SQ_D6, SQ_E6, SQ_F6, SQ_G6, SQ_H6, - SQ_A7, SQ_B7, SQ_C7, SQ_D7, SQ_E7, SQ_F7, SQ_G7, SQ_H7, - SQ_A8, SQ_B8, SQ_C8, SQ_D8, SQ_E8, SQ_F8, SQ_G8, SQ_H8, - SQ_NONE, - - SQUARE_ZERO = 0, - SQUARE_NB = 64 -}; - -enum Direction : int { - NORTH = 8, - EAST = 1, - SOUTH = -NORTH, - WEST = -EAST, - - NORTH_EAST = NORTH + EAST, - SOUTH_EAST = SOUTH + EAST, - SOUTH_WEST = SOUTH + WEST, - NORTH_WEST = NORTH + WEST -}; - -enum File : int { - FILE_A, FILE_B, FILE_C, FILE_D, FILE_E, FILE_F, FILE_G, FILE_H, FILE_NB -}; - -enum Rank : int { - RANK_1, RANK_2, RANK_3, RANK_4, RANK_5, RANK_6, RANK_7, RANK_8, RANK_NB -}; - -// Keep track of what a move changes on the board (used by NNUE) -struct DirtyPiece { - - // Number of changed pieces - int dirty_num; - - // Max 3 pieces can change in one move. A promotion with capture moves - // both the pawn and the captured piece to SQ_NONE and the piece promoted - // to from SQ_NONE to the capture square. - Piece piece[3]; - - // From and to squares, which may be SQ_NONE - Square from[3]; - Square to[3]; -}; - -#define ENABLE_BASE_OPERATORS_ON(T) \ -constexpr T operator+(T d1, int d2) { return T(int(d1) + d2); } \ -constexpr T operator-(T d1, int d2) { return T(int(d1) - d2); } \ -constexpr T operator-(T d) { return T(-int(d)); } \ -inline T& operator+=(T& d1, int d2) { return d1 = d1 + d2; } \ -inline T& operator-=(T& d1, int d2) { return d1 = d1 - d2; } - -#define ENABLE_INCR_OPERATORS_ON(T) \ -inline T& operator++(T& d) { return d = T(int(d) + 1); } \ -inline T& operator--(T& d) { return d = T(int(d) - 1); } - -#define ENABLE_FULL_OPERATORS_ON(T) \ -ENABLE_BASE_OPERATORS_ON(T) \ -constexpr T operator*(int i, T d) { return T(i * int(d)); } \ -constexpr T operator*(T d, int i) { return T(int(d) * i); } \ -constexpr T operator/(T d, int i) { return T(int(d) / i); } \ -constexpr int operator/(T d1, T d2) { return int(d1) / int(d2); } \ -inline T& operator*=(T& d, int i) { return d = T(int(d) * i); } \ -inline T& operator/=(T& d, int i) { return d = T(int(d) / i); } - -ENABLE_FULL_OPERATORS_ON(Value) -ENABLE_FULL_OPERATORS_ON(Direction) - -ENABLE_INCR_OPERATORS_ON(PieceType) -ENABLE_INCR_OPERATORS_ON(Square) -ENABLE_INCR_OPERATORS_ON(File) -ENABLE_INCR_OPERATORS_ON(Rank) - -#undef ENABLE_FULL_OPERATORS_ON -#undef ENABLE_INCR_OPERATORS_ON -#undef ENABLE_BASE_OPERATORS_ON - -/// Additional operators to add a Direction to a Square -constexpr Square operator+(Square s, Direction d) { return Square(int(s) + int(d)); } -constexpr Square operator-(Square s, Direction d) { return Square(int(s) - int(d)); } -inline Square& operator+=(Square& s, Direction d) { return s = s + d; } -inline Square& operator-=(Square& s, Direction d) { return s = s - d; } - -constexpr Color operator~(Color c) { - return Color(c ^ BLACK); // Toggle color -} - -constexpr Square flip_rank(Square s) { // Swap A1 <-> A8 - return Square(s ^ SQ_A8); -} - -constexpr Square flip_file(Square s) { // Swap A1 <-> H1 - return Square(s ^ SQ_H1); -} - -constexpr Piece operator~(Piece pc) { - return Piece(pc ^ 8); // Swap color of piece B_KNIGHT <-> W_KNIGHT -} - -constexpr CastlingRights operator&(Color c, CastlingRights cr) { - return CastlingRights((c == WHITE ? WHITE_CASTLING : BLACK_CASTLING) & cr); -} - -constexpr Value mate_in(int ply) { - return VALUE_MATE - ply; -} - -constexpr Value mated_in(int ply) { - return -VALUE_MATE + ply; -} - -constexpr Square make_square(File f, Rank r) { - return Square((r << 3) + f); -} - -constexpr Piece make_piece(Color c, PieceType pt) { - return Piece((c << 3) + pt); -} - -constexpr PieceType type_of(Piece pc) { - return PieceType(pc & 7); -} - -inline Color color_of(Piece pc) { - assert(pc != NO_PIECE); - return Color(pc >> 3); -} - -constexpr bool is_ok(Move m) { - return m != MOVE_NONE && m != MOVE_NULL; -} - -constexpr bool is_ok(Square s) { - return s >= SQ_A1 && s <= SQ_H8; -} - -constexpr File file_of(Square s) { - return File(s & 7); -} - -constexpr Rank rank_of(Square s) { - return Rank(s >> 3); -} - -constexpr Square relative_square(Color c, Square s) { - return Square(s ^ (c * 56)); -} - -constexpr Rank relative_rank(Color c, Rank r) { - return Rank(r ^ (c * 7)); -} - -constexpr Rank relative_rank(Color c, Square s) { - return relative_rank(c, rank_of(s)); -} - -constexpr Direction pawn_push(Color c) { - return c == WHITE ? NORTH : SOUTH; -} - -constexpr Square from_sq(Move m) { - assert(is_ok(m)); - return Square((m >> 6) & 0x3F); -} + constexpr CastlingRights operator&(Color c, CastlingRights cr) { + return CastlingRights((c == WHITE ? WHITE_CASTLING : BLACK_CASTLING) & cr); + } + + constexpr Value mate_in(int ply) { return VALUE_MATE - ply; } + + constexpr Value mated_in(int ply) { return -VALUE_MATE + ply; } + + constexpr Square make_square(File f, Rank r) { return Square((r << 3) + f); } + + constexpr Piece make_piece(Color c, PieceType pt) { return Piece((c << 3) + pt); } + + constexpr PieceType type_of(Piece pc) { return PieceType(pc & 7); } + + inline Color color_of(Piece pc) { + assert(pc != NO_PIECE); + return Color(pc >> 3); + } + + constexpr bool is_ok(Move m) { return m != MOVE_NONE && m != MOVE_NULL; } + + constexpr bool is_ok(Square s) { return s >= SQ_A1 && s <= SQ_H8; } + + constexpr File file_of(Square s) { return File(s & 7); } + + constexpr Rank rank_of(Square s) { return Rank(s >> 3); } + + constexpr Square relative_square(Color c, Square s) { return Square(s ^ (c * 56)); } + + constexpr Rank relative_rank(Color c, Rank r) { return Rank(r ^ (c * 7)); } + + constexpr Rank relative_rank(Color c, Square s) { return relative_rank(c, rank_of(s)); } + + constexpr Direction pawn_push(Color c) { return c == WHITE ? NORTH : SOUTH; } + + constexpr Square from_sq(Move m) { + assert(is_ok(m)); + return Square((m >> 6) & 0x3F); + } + + constexpr Square to_sq(Move m) { + assert(is_ok(m)); + return Square(m & 0x3F); + } + + constexpr int from_to(Move m) { return m & 0xFFF; } + + constexpr MoveType type_of(Move m) { return MoveType(m & (3 << 14)); } -constexpr Square to_sq(Move m) { - assert(is_ok(m)); - return Square(m & 0x3F); -} - -constexpr int from_to(Move m) { - return m & 0xFFF; -} - -constexpr MoveType type_of(Move m) { - return MoveType(m & (3 << 14)); -} - -constexpr PieceType promotion_type(Move m) { - return PieceType(((m >> 12) & 3) + KNIGHT); -} + constexpr PieceType promotion_type(Move m) { return PieceType(((m >> 12) & 3) + KNIGHT); } -constexpr Move make_move(Square from, Square to) { - return Move((from << 6) + to); -} + constexpr Move make_move(Square from, Square to) { return Move((from << 6) + to); } -template -constexpr Move make(Square from, Square to, PieceType pt = KNIGHT) { - return Move(T + ((pt - KNIGHT) << 12) + (from << 6) + to); -} + template constexpr Move make(Square from, Square to, PieceType pt = KNIGHT) { + return Move(T + ((pt - KNIGHT) << 12) + (from << 6) + to); + } -/// Based on a congruential pseudo random number generator -constexpr Key make_key(uint64_t seed) { - return seed * 6364136223846793005ULL + 1442695040888963407ULL; -} + /// Based on a congruential pseudo random number generator + constexpr Key make_key(uint64_t seed) { + return seed * 6364136223846793005ULL + 1442695040888963407ULL; + } -} // namespace Stockfish +} // namespace Stockfish -#endif // #ifndef TYPES_H_INCLUDED +#endif // #ifndef TYPES_H_INCLUDED -#include "tune.h" // Global visibility to tuning setup +#include "tune.h" // Global visibility to tuning setup diff --git a/src/uci.cpp b/src/uci.cpp index f3e436ef3aa..575e730938d 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -43,370 +43,377 @@ namespace Stockfish { -namespace { + namespace { - // FEN string for the initial position in standard chess - const char* StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"; + // FEN string for the initial position in standard chess + const char* StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"; - // position() is called when the engine receives the "position" UCI command. - // It sets up the position that is described in the given FEN string ("fen") or - // the initial position ("startpos") and then makes the moves given in the following - // move list ("moves"). + // position() is called when the engine receives the "position" UCI command. + // It sets up the position that is described in the given FEN string ("fen") or + // the initial position ("startpos") and then makes the moves given in the following + // move list ("moves"). - void position(Position& pos, std::istringstream& is, StateListPtr& states) { + void position(Position& pos, std::istringstream& is, StateListPtr& states) { - Move m; - std::string token, fen; + Move m; + std::string token, fen; - is >> token; + is >> token; - if (token == "startpos") - { - fen = StartFEN; - is >> token; // Consume the "moves" token, if any - } - else if (token == "fen") - while (is >> token && token != "moves") - fen += token + " "; - else - return; - - states = StateListPtr(new std::deque(1)); // Drop the old state and create a new one - pos.set(fen, Options["UCI_Chess960"], &states->back(), Threads.main()); - - // Parse the move list, if any - while (is >> token && (m = UCI::to_move(pos, token)) != MOVE_NONE) - { - states->emplace_back(); - pos.do_move(m, states->back()); - } - } + if (token == "startpos") { + fen = StartFEN; + is >> token; // Consume the "moves" token, if any + } else if (token == "fen") + while (is >> token && token != "moves") fen += token + " "; + else + return; - // trace_eval() prints the evaluation of the current position, consistent with - // the UCI options set so far. + states = StateListPtr( + new std::deque(1)); // Drop the old state and create a new one + pos.set(fen, Options["UCI_Chess960"], &states->back(), Threads.main()); - void trace_eval(Position& pos) { + // Parse the move list, if any + while (is >> token && (m = UCI::to_move(pos, token)) != MOVE_NONE) { + states->emplace_back(); + pos.do_move(m, states->back()); + } + } - StateListPtr states(new std::deque(1)); - Position p; - p.set(pos.fen(), Options["UCI_Chess960"], &states->back(), Threads.main()); + // trace_eval() prints the evaluation of the current position, consistent with + // the UCI options set so far. - Eval::NNUE::verify(); + void trace_eval(Position& pos) { - sync_cout << "\n" << Eval::trace(p) << sync_endl; - } + StateListPtr states(new std::deque(1)); + Position p; + p.set(pos.fen(), Options["UCI_Chess960"], &states->back(), Threads.main()); + Eval::NNUE::verify(); - // setoption() is called when the engine receives the "setoption" UCI command. - // The function updates the UCI option ("name") to the given value ("value"). + sync_cout << "\n" << Eval::trace(p) << sync_endl; + } - void setoption(std::istringstream& is) { - Threads.main()->wait_for_search_finished(); + // setoption() is called when the engine receives the "setoption" UCI command. + // The function updates the UCI option ("name") to the given value ("value"). - std::string token, name, value; + void setoption(std::istringstream& is) { - is >> token; // Consume the "name" token + Threads.main()->wait_for_search_finished(); - // Read the option name (can contain spaces) - while (is >> token && token != "value") - name += (name.empty() ? "" : " ") + token; + std::string token, name, value; - // Read the option value (can contain spaces) - while (is >> token) - value += (value.empty() ? "" : " ") + token; + is >> token; // Consume the "name" token - if (Options.count(name)) - Options[name] = value; - else - sync_cout << "No such option: " << name << sync_endl; - } + // Read the option name (can contain spaces) + while (is >> token && token != "value") name += (name.empty() ? "" : " ") + token; + // Read the option value (can contain spaces) + while (is >> token) value += (value.empty() ? "" : " ") + token; - // go() is called when the engine receives the "go" UCI command. The function - // sets the thinking time and other parameters from the input string, then starts - // with a search. + if (Options.count(name)) + Options[name] = value; + else + sync_cout << "No such option: " << name << sync_endl; + } - void go(Position& pos, std::istringstream& is, StateListPtr& states) { - Search::LimitsType limits; - std::string token; - bool ponderMode = false; + // go() is called when the engine receives the "go" UCI command. The function + // sets the thinking time and other parameters from the input string, then starts + // with a search. - limits.startTime = now(); // The search starts as early as possible + void go(Position& pos, std::istringstream& is, StateListPtr& states) { - while (is >> token) - if (token == "searchmoves") // Needs to be the last command on the line - while (is >> token) - limits.searchmoves.push_back(UCI::to_move(pos, token)); - - else if (token == "wtime") is >> limits.time[WHITE]; - else if (token == "btime") is >> limits.time[BLACK]; - else if (token == "winc") is >> limits.inc[WHITE]; - else if (token == "binc") is >> limits.inc[BLACK]; - else if (token == "movestogo") is >> limits.movestogo; - else if (token == "depth") is >> limits.depth; - else if (token == "nodes") is >> limits.nodes; - else if (token == "movetime") is >> limits.movetime; - else if (token == "mate") is >> limits.mate; - else if (token == "perft") is >> limits.perft; - else if (token == "infinite") limits.infinite = 1; - else if (token == "ponder") ponderMode = true; - - Threads.start_thinking(pos, states, limits, ponderMode); - } - - - // bench() is called when the engine receives the "bench" command. - // Firstly, a list of UCI commands is set up according to the bench - // parameters, then it is run one by one, printing a summary at the end. - - void bench(Position& pos, std::istream& args, StateListPtr& states) { - - std::string token; - uint64_t num, nodes = 0, cnt = 1; - - std::vector list = setup_bench(pos, args); - num = count_if(list.begin(), list.end(), [](const std::string& s) { return s.find("go ") == 0 || s.find("eval") == 0; }); - - TimePoint elapsed = now(); - - for (const auto& cmd : list) - { - std::istringstream is(cmd); - is >> std::skipws >> token; - - if (token == "go" || token == "eval") - { - std::cerr << "\nPosition: " << cnt++ << '/' << num << " (" << pos.fen() << ")" << std::endl; - if (token == "go") - { - go(pos, is, states); - Threads.main()->wait_for_search_finished(); - nodes += Threads.nodes_searched(); - } - else - trace_eval(pos); - } - else if (token == "setoption") setoption(is); - else if (token == "position") position(pos, is, states); - else if (token == "ucinewgame") { Search::clear(); elapsed = now(); } // Search::clear() may take a while - } + Search::LimitsType limits; + std::string token; + bool ponderMode = false; - elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero' + limits.startTime = now(); // The search starts as early as possible - dbg_print(); - - std::cerr << "\n===========================" - << "\nTotal time (ms) : " << elapsed - << "\nNodes searched : " << nodes - << "\nNodes/second : " << 1000 * nodes / elapsed << std::endl; - } - - // The win rate model returns the probability of winning (in per mille units) given an - // eval and a game ply. It fits the LTC fishtest statistics rather accurately. - int win_rate_model(Value v, int ply) { - - // The model only captures up to 240 plies, so limit the input and then rescale - double m = std::min(240, ply) / 64.0; - - // The coefficients of a third-order polynomial fit is based on the fishtest data - // for two parameters that need to transform eval to the argument of a logistic - // function. - constexpr double as[] = { 0.38036525, -2.82015070, 23.17882135, 307.36768407}; - constexpr double bs[] = { -2.29434733, 13.27689788, -14.26828904, 63.45318330 }; - - // Enforce that NormalizeToPawnValue corresponds to a 50% win rate at ply 64 - static_assert(UCI::NormalizeToPawnValue == int(as[0] + as[1] + as[2] + as[3])); - - double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3]; - double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3]; - - // Transform the eval to centipawns with limited range - double x = std::clamp(double(v), -4000.0, 4000.0); - - // Return the win rate in per mille units rounded to the nearest value - return int(0.5 + 1000 / (1 + std::exp((a - x) / b))); - } + while (is >> token) + if (token == "searchmoves") // Needs to be the last command on the line + while (is >> token) limits.searchmoves.push_back(UCI::to_move(pos, token)); + + else if (token == "wtime") + is >> limits.time[WHITE]; + else if (token == "btime") + is >> limits.time[BLACK]; + else if (token == "winc") + is >> limits.inc[WHITE]; + else if (token == "binc") + is >> limits.inc[BLACK]; + else if (token == "movestogo") + is >> limits.movestogo; + else if (token == "depth") + is >> limits.depth; + else if (token == "nodes") + is >> limits.nodes; + else if (token == "movetime") + is >> limits.movetime; + else if (token == "mate") + is >> limits.mate; + else if (token == "perft") + is >> limits.perft; + else if (token == "infinite") + limits.infinite = 1; + else if (token == "ponder") + ponderMode = true; + + Threads.start_thinking(pos, states, limits, ponderMode); + } -} // namespace + // bench() is called when the engine receives the "bench" command. + // Firstly, a list of UCI commands is set up according to the bench + // parameters, then it is run one by one, printing a summary at the end. + + void bench(Position& pos, std::istream& args, StateListPtr& states) { + + std::string token; + uint64_t num, nodes = 0, cnt = 1; + + std::vector list = setup_bench(pos, args); + num = count_if(list.begin(), list.end(), [](const std::string& s) { + return s.find("go ") == 0 || s.find("eval") == 0; + }); + + TimePoint elapsed = now(); + + for (const auto& cmd : list) { + std::istringstream is(cmd); + is >> std::skipws >> token; + + if (token == "go" || token == "eval") { + std::cerr << "\nPosition: " << cnt++ << '/' << num << " (" << pos.fen() << ")" + << std::endl; + if (token == "go") { + go(pos, is, states); + Threads.main()->wait_for_search_finished(); + nodes += Threads.nodes_searched(); + } else + trace_eval(pos); + } else if (token == "setoption") + setoption(is); + else if (token == "position") + position(pos, is, states); + else if (token == "ucinewgame") { + Search::clear(); + elapsed = now(); + } // Search::clear() may take a while + } -/// UCI::loop() waits for a command from the stdin, parses it and then calls the appropriate -/// function. It also intercepts an end-of-file (EOF) indication from the stdin to ensure a -/// graceful exit if the GUI dies unexpectedly. When called with some command-line arguments, -/// like running 'bench', the function returns immediately after the command is executed. -/// In addition to the UCI ones, some additional debug commands are also supported. - -void UCI::loop(int argc, char* argv[]) { - - Position pos; - std::string token, cmd; - StateListPtr states(new std::deque(1)); - - pos.set(StartFEN, false, &states->back(), Threads.main()); - - for (int i = 1; i < argc; ++i) - cmd += std::string(argv[i]) + " "; + elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero' - do { - if (argc == 1 && !getline(std::cin, cmd)) // Wait for an input or an end-of-file (EOF) indication - cmd = "quit"; + dbg_print(); - std::istringstream is(cmd); + std::cerr << "\n===========================" + << "\nTotal time (ms) : " << elapsed << "\nNodes searched : " << nodes + << "\nNodes/second : " << 1000 * nodes / elapsed << std::endl; + } - token.clear(); // Avoid a stale if getline() returns nothing or a blank line - is >> std::skipws >> token; + // The win rate model returns the probability of winning (in per mille units) given an + // eval and a game ply. It fits the LTC fishtest statistics rather accurately. + int win_rate_model(Value v, int ply) { - if ( token == "quit" - || token == "stop") - Threads.stop = true; + // The model only captures up to 240 plies, so limit the input and then rescale + double m = std::min(240, ply) / 64.0; - // The GUI sends 'ponderhit' to tell that the user has played the expected move. - // So, 'ponderhit' is sent if pondering was done on the same move that the user - // has played. The search should continue, but should also switch from pondering - // to the normal search. - else if (token == "ponderhit") - Threads.main()->ponder = false; // Switch to the normal search + // The coefficients of a third-order polynomial fit is based on the fishtest data + // for two parameters that need to transform eval to the argument of a logistic + // function. + constexpr double as[] = {0.38036525, -2.82015070, 23.17882135, 307.36768407}; + constexpr double bs[] = {-2.29434733, 13.27689788, -14.26828904, 63.45318330}; - else if (token == "uci") - sync_cout << "id name " << engine_info(true) - << "\n" << Options - << "\nuciok" << sync_endl; + // Enforce that NormalizeToPawnValue corresponds to a 50% win rate at ply 64 + static_assert(UCI::NormalizeToPawnValue == int(as[0] + as[1] + as[2] + as[3])); - else if (token == "setoption") setoption(is); - else if (token == "go") go(pos, is, states); - else if (token == "position") position(pos, is, states); - else if (token == "ucinewgame") Search::clear(); - else if (token == "isready") sync_cout << "readyok" << sync_endl; + double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3]; + double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3]; - // Add custom non-UCI commands, mainly for debugging purposes. - // These commands must not be used during a search! - else if (token == "flip") pos.flip(); - else if (token == "bench") bench(pos, is, states); - else if (token == "d") sync_cout << pos << sync_endl; - else if (token == "eval") trace_eval(pos); - else if (token == "compiler") sync_cout << compiler_info() << sync_endl; - else if (token == "export_net") - { - std::optional filename; - std::string f; - if (is >> std::skipws >> f) - filename = f; - Eval::NNUE::save_eval(filename); - } - else if (token == "--help" || token == "help" || token == "--license" || token == "license") - sync_cout << "\nStockfish is a powerful chess engine for playing and analyzing." - "\nIt is released as free software licensed under the GNU GPLv3 License." - "\nStockfish is normally used with a graphical user interface (GUI) and implements" - "\nthe Universal Chess Interface (UCI) protocol to communicate with a GUI, an API, etc." - "\nFor any further information, visit https://github.com/official-stockfish/Stockfish#readme" - "\nor read the corresponding README.md and Copying.txt files distributed along with this program.\n" << sync_endl; - else if (!token.empty() && token[0] != '#') - sync_cout << "Unknown command: '" << cmd << "'. Type help for more information." << sync_endl; + // Transform the eval to centipawns with limited range + double x = std::clamp(double(v), -4000.0, 4000.0); - } while (token != "quit" && argc == 1); // The command-line arguments are one-shot -} + // Return the win rate in per mille units rounded to the nearest value + return int(0.5 + 1000 / (1 + std::exp((a - x) / b))); + } + } // namespace + + + /// UCI::loop() waits for a command from the stdin, parses it and then calls the appropriate + /// function. It also intercepts an end-of-file (EOF) indication from the stdin to ensure a + /// graceful exit if the GUI dies unexpectedly. When called with some command-line arguments, + /// like running 'bench', the function returns immediately after the command is executed. + /// In addition to the UCI ones, some additional debug commands are also supported. + + void UCI::loop(int argc, char* argv[]) { + + Position pos; + std::string token, cmd; + StateListPtr states(new std::deque(1)); + + pos.set(StartFEN, false, &states->back(), Threads.main()); + + for (int i = 1; i < argc; ++i) cmd += std::string(argv[i]) + " "; + + do { + if (argc == 1 && + !getline(std::cin, cmd)) // Wait for an input or an end-of-file (EOF) indication + cmd = "quit"; + + std::istringstream is(cmd); + + token.clear(); // Avoid a stale if getline() returns nothing or a blank line + is >> std::skipws >> token; + + if (token == "quit" || token == "stop") Threads.stop = true; + + // The GUI sends 'ponderhit' to tell that the user has played the expected move. + // So, 'ponderhit' is sent if pondering was done on the same move that the user + // has played. The search should continue, but should also switch from pondering + // to the normal search. + else if (token == "ponderhit") + Threads.main()->ponder = false; // Switch to the normal search + + else if (token == "uci") + sync_cout << "id name " << engine_info(true) << "\n" + << Options << "\nuciok" << sync_endl; + + else if (token == "setoption") + setoption(is); + else if (token == "go") + go(pos, is, states); + else if (token == "position") + position(pos, is, states); + else if (token == "ucinewgame") + Search::clear(); + else if (token == "isready") + sync_cout << "readyok" << sync_endl; + + // Add custom non-UCI commands, mainly for debugging purposes. + // These commands must not be used during a search! + else if (token == "flip") + pos.flip(); + else if (token == "bench") + bench(pos, is, states); + else if (token == "d") + sync_cout << pos << sync_endl; + else if (token == "eval") + trace_eval(pos); + else if (token == "compiler") + sync_cout << compiler_info() << sync_endl; + else if (token == "export_net") { + std::optional filename; + std::string f; + if (is >> std::skipws >> f) filename = f; + Eval::NNUE::save_eval(filename); + } else if (token == "--help" || token == "help" || token == "--license" || + token == "license") + sync_cout + << "\nStockfish is a powerful chess engine for playing and analyzing." + "\nIt is released as free software licensed under the GNU GPLv3 License." + "\nStockfish is normally used with a graphical user interface (GUI) and implements" + "\nthe Universal Chess Interface (UCI) protocol to communicate with a GUI, an API, etc." + "\nFor any further information, visit https://github.com/official-stockfish/Stockfish#readme" + "\nor read the corresponding README.md and Copying.txt files distributed along with this program.\n" + << sync_endl; + else if (!token.empty() && token[0] != '#') + sync_cout << "Unknown command: '" << cmd << "'. Type help for more information." + << sync_endl; + + } while (token != "quit" && argc == 1); // The command-line arguments are one-shot + } -/// Turns a Value to an integer centipawn number, -/// without treatment of mate and similar special scores. -int UCI::to_cp(Value v) { - return 100 * v / UCI::NormalizeToPawnValue; -} + /// Turns a Value to an integer centipawn number, + /// without treatment of mate and similar special scores. + int UCI::to_cp(Value v) { return 100 * v / UCI::NormalizeToPawnValue; } -/// UCI::value() converts a Value to a string by adhering to the UCI protocol specification: -/// -/// cp The score from the engine's point of view in centipawns. -/// mate Mate in 'y' moves (not plies). If the engine is getting mated, -/// uses negative values for 'y'. + /// UCI::value() converts a Value to a string by adhering to the UCI protocol specification: + /// + /// cp The score from the engine's point of view in centipawns. + /// mate Mate in 'y' moves (not plies). If the engine is getting mated, + /// uses negative values for 'y'. -std::string UCI::value(Value v) { + std::string UCI::value(Value v) { - assert(-VALUE_INFINITE < v && v < VALUE_INFINITE); + assert(-VALUE_INFINITE < v && v < VALUE_INFINITE); - std::stringstream ss; + std::stringstream ss; - if (abs(v) < VALUE_TB_WIN_IN_MAX_PLY) - ss << "cp " << UCI::to_cp(v); - else if (abs(v) < VALUE_MATE_IN_MAX_PLY) - { - const int ply = VALUE_MATE_IN_MAX_PLY - 1 - std::abs(v); // recompute ss->ply - ss << "cp " << (v > 0 ? 20000 - ply : -20000 + ply); - } - else - ss << "mate " << (v > 0 ? VALUE_MATE - v + 1 : -VALUE_MATE - v) / 2; + if (abs(v) < VALUE_TB_WIN_IN_MAX_PLY) + ss << "cp " << UCI::to_cp(v); + else if (abs(v) < VALUE_MATE_IN_MAX_PLY) { + const int ply = VALUE_MATE_IN_MAX_PLY - 1 - std::abs(v); // recompute ss->ply + ss << "cp " << (v > 0 ? 20000 - ply : -20000 + ply); + } else + ss << "mate " << (v > 0 ? VALUE_MATE - v + 1 : -VALUE_MATE - v) / 2; - return ss.str(); -} + return ss.str(); + } -/// UCI::wdl() reports the win-draw-loss (WDL) statistics given an evaluation -/// and a game ply based on the data gathered for fishtest LTC games. + /// UCI::wdl() reports the win-draw-loss (WDL) statistics given an evaluation + /// and a game ply based on the data gathered for fishtest LTC games. -std::string UCI::wdl(Value v, int ply) { + std::string UCI::wdl(Value v, int ply) { - std::stringstream ss; + std::stringstream ss; - int wdl_w = win_rate_model( v, ply); - int wdl_l = win_rate_model(-v, ply); - int wdl_d = 1000 - wdl_w - wdl_l; - ss << " wdl " << wdl_w << " " << wdl_d << " " << wdl_l; + int wdl_w = win_rate_model(v, ply); + int wdl_l = win_rate_model(-v, ply); + int wdl_d = 1000 - wdl_w - wdl_l; + ss << " wdl " << wdl_w << " " << wdl_d << " " << wdl_l; - return ss.str(); -} + return ss.str(); + } -/// UCI::square() converts a Square to a string in algebraic notation (g1, a7, etc.) + /// UCI::square() converts a Square to a string in algebraic notation (g1, a7, etc.) -std::string UCI::square(Square s) { - return std::string{ char('a' + file_of(s)), char('1' + rank_of(s)) }; -} + std::string UCI::square(Square s) { + return std::string{char('a' + file_of(s)), char('1' + rank_of(s))}; + } -/// UCI::move() converts a Move to a string in coordinate notation (g1f3, a7a8q). -/// The only special case is castling where the e1g1 notation is printed in -/// standard chess mode and in e1h1 notation it is printed in Chess960 mode. -/// Internally, all castling moves are always encoded as 'king captures rook'. + /// UCI::move() converts a Move to a string in coordinate notation (g1f3, a7a8q). + /// The only special case is castling where the e1g1 notation is printed in + /// standard chess mode and in e1h1 notation it is printed in Chess960 mode. + /// Internally, all castling moves are always encoded as 'king captures rook'. -std::string UCI::move(Move m, bool chess960) { + std::string UCI::move(Move m, bool chess960) { - if (m == MOVE_NONE) - return "(none)"; + if (m == MOVE_NONE) return "(none)"; - if (m == MOVE_NULL) - return "0000"; + if (m == MOVE_NULL) return "0000"; - Square from = from_sq(m); - Square to = to_sq(m); + Square from = from_sq(m); + Square to = to_sq(m); - if (type_of(m) == CASTLING && !chess960) - to = make_square(to > from ? FILE_G : FILE_C, rank_of(from)); + if (type_of(m) == CASTLING && !chess960) + to = make_square(to > from ? FILE_G : FILE_C, rank_of(from)); - std::string move = UCI::square(from) + UCI::square(to); + std::string move = UCI::square(from) + UCI::square(to); - if (type_of(m) == PROMOTION) - move += " pnbrqk"[promotion_type(m)]; + if (type_of(m) == PROMOTION) move += " pnbrqk"[promotion_type(m)]; - return move; -} + return move; + } -/// UCI::to_move() converts a string representing a move in coordinate notation -/// (g1f3, a7a8q) to the corresponding legal Move, if any. + /// UCI::to_move() converts a string representing a move in coordinate notation + /// (g1f3, a7a8q) to the corresponding legal Move, if any. -Move UCI::to_move(const Position& pos, std::string& str) { + Move UCI::to_move(const Position& pos, std::string& str) { - if (str.length() == 5) - str[4] = char(tolower(str[4])); // The promotion piece character must be lowercased + if (str.length() == 5) + str[4] = char(tolower(str[4])); // The promotion piece character must be lowercased - for (const auto& m : MoveList(pos)) - if (str == UCI::move(m, pos.is_chess960())) - return m; + for (const auto& m : MoveList(pos)) + if (str == UCI::move(m, pos.is_chess960())) return m; - return MOVE_NONE; -} + return MOVE_NONE; + } -} // namespace Stockfish +} // namespace Stockfish diff --git a/src/uci.h b/src/uci.h index 7ca97d5c6bc..342a82623ff 100644 --- a/src/uci.h +++ b/src/uci.h @@ -28,68 +28,68 @@ namespace Stockfish { -class Position; + class Position; -namespace UCI { + namespace UCI { -// Normalizes the internal value as reported by evaluate or search -// to the UCI centipawn result used in output. This value is derived from -// the win_rate_model() such that Stockfish outputs an advantage of -// "100 centipawns" for a position if the engine has a 50% probability to win -// from this position in selfplay at fishtest LTC time control. -const int NormalizeToPawnValue = 328; + // Normalizes the internal value as reported by evaluate or search + // to the UCI centipawn result used in output. This value is derived from + // the win_rate_model() such that Stockfish outputs an advantage of + // "100 centipawns" for a position if the engine has a 50% probability to win + // from this position in selfplay at fishtest LTC time control. + const int NormalizeToPawnValue = 328; -class Option; + class Option; -/// Define a custom comparator, because the UCI options should be case-insensitive -struct CaseInsensitiveLess { - bool operator() (const std::string&, const std::string&) const; -}; + /// Define a custom comparator, because the UCI options should be case-insensitive + struct CaseInsensitiveLess { + bool operator()(const std::string&, const std::string&) const; + }; -/// The options container is defined as a std::map -using OptionsMap = std::map; + /// The options container is defined as a std::map + using OptionsMap = std::map; -/// The Option class implements each option as specified by the UCI protocol -class Option { + /// The Option class implements each option as specified by the UCI protocol + class Option { - using OnChange = void (*)(const Option&); + using OnChange = void (*)(const Option&); -public: - Option(OnChange = nullptr); - Option(bool v, OnChange = nullptr); - Option(const char* v, OnChange = nullptr); - Option(double v, int minv, int maxv, OnChange = nullptr); - Option(const char* v, const char* cur, OnChange = nullptr); + public: + Option(OnChange = nullptr); + Option(bool v, OnChange = nullptr); + Option(const char* v, OnChange = nullptr); + Option(double v, int minv, int maxv, OnChange = nullptr); + Option(const char* v, const char* cur, OnChange = nullptr); - Option& operator=(const std::string&); - void operator<<(const Option&); - operator int() const; - operator std::string() const; - bool operator==(const char*) const; + Option& operator=(const std::string&); + void operator<<(const Option&); + operator int() const; + operator std::string() const; + bool operator==(const char*) const; -private: - friend std::ostream& operator<<(std::ostream&, const OptionsMap&); + private: + friend std::ostream& operator<<(std::ostream&, const OptionsMap&); - std::string defaultValue, currentValue, type; - int min, max; - size_t idx; - OnChange on_change; -}; + std::string defaultValue, currentValue, type; + int min, max; + size_t idx; + OnChange on_change; + }; -void init(OptionsMap&); -void loop(int argc, char* argv[]); -int to_cp(Value v); -std::string value(Value v); -std::string square(Square s); -std::string move(Move m, bool chess960); -std::string pv(const Position& pos, Depth depth); -std::string wdl(Value v, int ply); -Move to_move(const Position& pos, std::string& str); + void init(OptionsMap&); + void loop(int argc, char* argv[]); + int to_cp(Value v); + std::string value(Value v); + std::string square(Square s); + std::string move(Move m, bool chess960); + std::string pv(const Position& pos, Depth depth); + std::string wdl(Value v, int ply); + Move to_move(const Position& pos, std::string& str); -} // namespace UCI + } // namespace UCI -extern UCI::OptionsMap Options; + extern UCI::OptionsMap Options; -} // namespace Stockfish +} // namespace Stockfish -#endif // #ifndef UCI_H_INCLUDED +#endif // #ifndef UCI_H_INCLUDED diff --git a/src/ucioption.cpp b/src/ucioption.cpp index 8d2c5c098ed..9fb4b0ebf0e 100644 --- a/src/ucioption.cpp +++ b/src/ucioption.cpp @@ -40,160 +40,160 @@ using std::string; namespace Stockfish { -UCI::OptionsMap Options; // Global object + UCI::OptionsMap Options; // Global object -namespace UCI { + namespace UCI { -/// 'On change' actions, triggered by an option's value change -static void on_clear_hash(const Option&) { Search::clear(); } -static void on_hash_size(const Option& o) { TT.resize(size_t(o)); } -static void on_logger(const Option& o) { start_logger(o); } -static void on_threads(const Option& o) { Threads.set(size_t(o)); } -static void on_tb_path(const Option& o) { Tablebases::init(o); } -static void on_eval_file(const Option&) { Eval::NNUE::init(); } + /// 'On change' actions, triggered by an option's value change + static void on_clear_hash(const Option&) { Search::clear(); } + static void on_hash_size(const Option& o) { TT.resize(size_t(o)); } + static void on_logger(const Option& o) { start_logger(o); } + static void on_threads(const Option& o) { Threads.set(size_t(o)); } + static void on_tb_path(const Option& o) { Tablebases::init(o); } + static void on_eval_file(const Option&) { Eval::NNUE::init(); } + + /// Our case insensitive less() function as required by UCI protocol + bool CaseInsensitiveLess::operator()(const string& s1, const string& s2) const { -/// Our case insensitive less() function as required by UCI protocol -bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return tolower(c1) < tolower(c2); }); + } - return std::lexicographical_compare(s1.begin(), s1.end(), s2.begin(), s2.end(), - [](char c1, char c2) { return tolower(c1) < tolower(c2); }); -} + /// UCI::init() initializes the UCI options to their hard-coded default values -/// UCI::init() initializes the UCI options to their hard-coded default values + void init(OptionsMap& o) { -void init(OptionsMap& o) { + constexpr int MaxHashMB = Is64Bit ? 33554432 : 2048; - constexpr int MaxHashMB = Is64Bit ? 33554432 : 2048; + o["Debug Log File"] << Option("", on_logger); + o["Threads"] << Option(1, 1, 1024, on_threads); + o["Hash"] << Option(16, 1, MaxHashMB, on_hash_size); + o["Clear Hash"] << Option(on_clear_hash); + o["Ponder"] << Option(false); + o["MultiPV"] << Option(1, 1, 500); + o["Skill Level"] << Option(20, 0, 20); + o["Move Overhead"] << Option(10, 0, 5000); + o["Slow Mover"] << Option(100, 10, 1000); + o["nodestime"] << Option(0, 0, 10000); + o["UCI_Chess960"] << Option(false); + o["UCI_AnalyseMode"] << Option(false); + o["UCI_LimitStrength"] << Option(false); + o["UCI_Elo"] << Option(1320, 1320, 3190); + o["UCI_ShowWDL"] << Option(false); + o["SyzygyPath"] << Option("", on_tb_path); + o["SyzygyProbeDepth"] << Option(1, 1, 100); + o["Syzygy50MoveRule"] << Option(true); + o["SyzygyProbeLimit"] << Option(7, 0, 7); + o["EvalFile"] << Option(EvalFileDefaultName, on_eval_file); + } - o["Debug Log File"] << Option("", on_logger); - o["Threads"] << Option(1, 1, 1024, on_threads); - o["Hash"] << Option(16, 1, MaxHashMB, on_hash_size); - o["Clear Hash"] << Option(on_clear_hash); - o["Ponder"] << Option(false); - o["MultiPV"] << Option(1, 1, 500); - o["Skill Level"] << Option(20, 0, 20); - o["Move Overhead"] << Option(10, 0, 5000); - o["Slow Mover"] << Option(100, 10, 1000); - o["nodestime"] << Option(0, 0, 10000); - o["UCI_Chess960"] << Option(false); - o["UCI_AnalyseMode"] << Option(false); - o["UCI_LimitStrength"] << Option(false); - o["UCI_Elo"] << Option(1320, 1320, 3190); - o["UCI_ShowWDL"] << Option(false); - o["SyzygyPath"] << Option("", on_tb_path); - o["SyzygyProbeDepth"] << Option(1, 1, 100); - o["Syzygy50MoveRule"] << Option(true); - o["SyzygyProbeLimit"] << Option(7, 0, 7); - o["EvalFile"] << Option(EvalFileDefaultName, on_eval_file); -} + /// operator<<() is used to print all the options default values in chronological + /// insertion order (the idx field) and in the format defined by the UCI protocol. -/// operator<<() is used to print all the options default values in chronological -/// insertion order (the idx field) and in the format defined by the UCI protocol. + std::ostream& operator<<(std::ostream& os, const OptionsMap& om) { -std::ostream& operator<<(std::ostream& os, const OptionsMap& om) { + for (size_t idx = 0; idx < om.size(); ++idx) + for (const auto& it : om) + if (it.second.idx == idx) { + const Option& o = it.second; + os << "\noption name " << it.first << " type " << o.type; - for (size_t idx = 0; idx < om.size(); ++idx) - for (const auto& it : om) - if (it.second.idx == idx) - { - const Option& o = it.second; - os << "\noption name " << it.first << " type " << o.type; + if (o.type == "string" || o.type == "check" || o.type == "combo") + os << " default " << o.defaultValue; - if (o.type == "string" || o.type == "check" || o.type == "combo") - os << " default " << o.defaultValue; + if (o.type == "spin") + os << " default " << int(stof(o.defaultValue)) << " min " << o.min + << " max " << o.max; - if (o.type == "spin") - os << " default " << int(stof(o.defaultValue)) - << " min " << o.min - << " max " << o.max; + break; + } - break; - } + return os; + } - return os; -} + /// Option class constructors and conversion operators -/// Option class constructors and conversion operators + Option::Option(const char* v, OnChange f) : type("string"), min(0), max(0), on_change(f) { + defaultValue = currentValue = v; + } -Option::Option(const char* v, OnChange f) : type("string"), min(0), max(0), on_change(f) -{ defaultValue = currentValue = v; } + Option::Option(bool v, OnChange f) : type("check"), min(0), max(0), on_change(f) { + defaultValue = currentValue = (v ? "true" : "false"); + } -Option::Option(bool v, OnChange f) : type("check"), min(0), max(0), on_change(f) -{ defaultValue = currentValue = (v ? "true" : "false"); } + Option::Option(OnChange f) : type("button"), min(0), max(0), on_change(f) {} -Option::Option(OnChange f) : type("button"), min(0), max(0), on_change(f) -{} + Option::Option(double v, int minv, int maxv, OnChange f) + : type("spin"), min(minv), max(maxv), on_change(f) { + defaultValue = currentValue = std::to_string(v); + } -Option::Option(double v, int minv, int maxv, OnChange f) : type("spin"), min(minv), max(maxv), on_change(f) -{ defaultValue = currentValue = std::to_string(v); } + Option::Option(const char* v, const char* cur, OnChange f) + : type("combo"), min(0), max(0), on_change(f) { + defaultValue = v; + currentValue = cur; + } -Option::Option(const char* v, const char* cur, OnChange f) : type("combo"), min(0), max(0), on_change(f) -{ defaultValue = v; currentValue = cur; } + Option::operator int() const { + assert(type == "check" || type == "spin"); + return (type == "spin" ? std::stoi(currentValue) : currentValue == "true"); + } -Option::operator int() const { - assert(type == "check" || type == "spin"); - return (type == "spin" ? std::stoi(currentValue) : currentValue == "true"); -} + Option::operator std::string() const { + assert(type == "string"); + return currentValue; + } -Option::operator std::string() const { - assert(type == "string"); - return currentValue; -} + bool Option::operator==(const char* s) const { + assert(type == "combo"); + return !CaseInsensitiveLess()(currentValue, s) && + !CaseInsensitiveLess()(s, currentValue); + } -bool Option::operator==(const char* s) const { - assert(type == "combo"); - return !CaseInsensitiveLess()(currentValue, s) - && !CaseInsensitiveLess()(s, currentValue); -} + /// operator<<() inits options and assigns idx in the correct printing order -/// operator<<() inits options and assigns idx in the correct printing order + void Option::operator<<(const Option& o) { -void Option::operator<<(const Option& o) { + static size_t insert_order = 0; - static size_t insert_order = 0; + *this = o; + idx = insert_order++; + } - *this = o; - idx = insert_order++; -} + /// operator=() updates currentValue and triggers on_change() action. It's up to + /// the GUI to check for option's limits, but we could receive the new value + /// from the user by console window, so let's check the bounds anyway. -/// operator=() updates currentValue and triggers on_change() action. It's up to -/// the GUI to check for option's limits, but we could receive the new value -/// from the user by console window, so let's check the bounds anyway. + Option& Option::operator=(const string& v) { -Option& Option::operator=(const string& v) { + assert(!type.empty()); - assert(!type.empty()); + if ((type != "button" && type != "string" && v.empty()) || + (type == "check" && v != "true" && v != "false") || + (type == "spin" && (stof(v) < min || stof(v) > max))) + return *this; - if ( (type != "button" && type != "string" && v.empty()) - || (type == "check" && v != "true" && v != "false") - || (type == "spin" && (stof(v) < min || stof(v) > max))) - return *this; + if (type == "combo") { + OptionsMap comboMap; // To have case insensitive compare + string token; + std::istringstream ss(defaultValue); + while (ss >> token) comboMap[token] << Option(); + if (!comboMap.count(v) || v == "var") return *this; + } - if (type == "combo") - { - OptionsMap comboMap; // To have case insensitive compare - string token; - std::istringstream ss(defaultValue); - while (ss >> token) - comboMap[token] << Option(); - if (!comboMap.count(v) || v == "var") - return *this; - } + if (type != "button") currentValue = v; - if (type != "button") - currentValue = v; + if (on_change) on_change(*this); - if (on_change) - on_change(*this); + return *this; + } - return *this; -} + } // namespace UCI -} // namespace UCI - -} // namespace Stockfish +} // namespace Stockfish