Skip to content

Commit

Permalink
Adjust LMR with correction history
Browse files Browse the repository at this point in the history
bench 1079115
  • Loading branch information
xu-shawn committed Dec 18, 2024
1 parent cf10644 commit 397b2f3
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,26 @@ constexpr int futility_move_count(bool improving, Depth depth) {
return (3 + depth * depth) / (2 - improving);
}

// Add correctionHistory value to raw staticEval and guarantee evaluation
// does not hit the tablebase range.
Value to_corrected_static_eval(Value v, const Worker& w, const Position& pos, Stack* ss) {
auto correction_value(const Worker& w, const Position& pos, Stack* ss) {
const Color us = pos.side_to_move();
const auto m = (ss - 1)->currentMove;
const auto pcv = w.pawnCorrectionHistory[us][pawn_structure_index<Correction>(pos)];
const auto macv = w.majorPieceCorrectionHistory[us][major_piece_index(pos)];
const auto micv = w.minorPieceCorrectionHistory[us][minor_piece_index(pos)];
const auto wnpcv = w.nonPawnCorrectionHistory[WHITE][us][non_pawn_index<WHITE>(pos)];
const auto bnpcv = w.nonPawnCorrectionHistory[BLACK][us][non_pawn_index<BLACK>(pos)];
int cntcv = 1;
const auto cntcv =
m.is_ok() ? (*(ss - 2)->continuationCorrectionHistory)[pos.piece_on(m.to_sq())][m.to_sq()]
: 0;

if (m.is_ok())
cntcv = int((*(ss - 2)->continuationCorrectionHistory)[pos.piece_on(m.to_sq())][m.to_sq()]);
return (6384 * pcv + 3583 * macv + 6492 * micv + 6725 * (wnpcv + bnpcv) + 5880 * cntcv);
}

const auto cv =
(6384 * pcv + 3583 * macv + 6492 * micv + 6725 * (wnpcv + bnpcv) + cntcv * 5880) / 131072;
v += cv;
return std::clamp(v, VALUE_TB_LOSS_IN_MAX_PLY + 1, VALUE_TB_WIN_IN_MAX_PLY - 1);
// Add correctionHistory value to raw staticEval and guarantee evaluation
// does not hit the tablebase range.
template<typename CorrectionType>
Value to_corrected_static_eval(Value v, const CorrectionType cv) {
return std::clamp(v + cv / 131072, VALUE_TB_LOSS_IN_MAX_PLY + 1, VALUE_TB_WIN_IN_MAX_PLY - 1);
}

// History and stats update bonus, based on depth
Expand Down Expand Up @@ -538,7 +539,7 @@ Value Search::Worker::search(

// Dive into quiescence search when the depth reaches zero
if (depth <= 0)
return qsearch < PvNode ? PV : NonPV > (pos, ss, alpha, beta);
return qsearch<PvNode ? PV : NonPV>(pos, ss, alpha, beta);

// Limit the depth if extensions made it too large
depth = std::min(depth, MAX_PLY - 1);
Expand Down Expand Up @@ -713,7 +714,8 @@ Value Search::Worker::search(
}

// Step 6. Static evaluation of the position
Value unadjustedStaticEval = VALUE_NONE;
Value unadjustedStaticEval = VALUE_NONE;
const auto correctionValue = correction_value(*thisThread, pos, ss);
if (ss->inCheck)
{
// Skip early pruning when in check
Expand All @@ -738,8 +740,7 @@ Value Search::Worker::search(
else if (PvNode)
Eval::NNUE::hint_common_parent_position(pos, networks[numaAccessToken], refreshTable);

ss->staticEval = eval =
to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos, ss);
ss->staticEval = eval = to_corrected_static_eval(unadjustedStaticEval, correctionValue);

// ttValue can be used as a better position evaluation (~7 Elo)
if (is_valid(ttData.value)
Expand All @@ -750,8 +751,7 @@ Value Search::Worker::search(
{
unadjustedStaticEval =
evaluate(networks[numaAccessToken], pos, refreshTable, thisThread->optimism[us]);
ss->staticEval = eval =
to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos, ss);
ss->staticEval = eval = to_corrected_static_eval(unadjustedStaticEval, correctionValue);

// Static evaluation is saved as it was before adjustment by correction history
ttWriter.write(posKey, VALUE_NONE, ss->ttPv, BOUND_NONE, DEPTH_UNSEARCHED, Move::none(),
Expand Down Expand Up @@ -1161,6 +1161,10 @@ Value Search::Worker::search(

// These reduction adjustments have no proven non-linear scaling

r += 330;

r -= std::min(std::abs(correctionValue) / 32768, 2048);

// Increase reduction for cut nodes (~4 Elo)
if (cutNode)
r += 2518 - (ttData.depth >= depth && ss->ttPv) * 991;
Expand Down Expand Up @@ -1532,7 +1536,8 @@ Value Search::Worker::qsearch(Position& pos, Stack* ss, Value alpha, Value beta)
return ttData.value;

// Step 4. Static evaluation of the position
Value unadjustedStaticEval = VALUE_NONE;
Value unadjustedStaticEval = VALUE_NONE;
const auto correctionValue = correction_value(*thisThread, pos, ss);
if (ss->inCheck)
bestValue = futilityBase = -VALUE_INFINITE;
else
Expand All @@ -1545,7 +1550,7 @@ Value Search::Worker::qsearch(Position& pos, Stack* ss, Value alpha, Value beta)
unadjustedStaticEval =
evaluate(networks[numaAccessToken], pos, refreshTable, thisThread->optimism[us]);
ss->staticEval = bestValue =
to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos, ss);
to_corrected_static_eval(unadjustedStaticEval, correctionValue);

// ttValue can be used as a better position evaluation (~13 Elo)
if (is_valid(ttData.value) && !is_decisive(ttData.value)
Expand All @@ -1560,7 +1565,7 @@ Value Search::Worker::qsearch(Position& pos, Stack* ss, Value alpha, Value beta)
? evaluate(networks[numaAccessToken], pos, refreshTable, thisThread->optimism[us])
: -(ss - 1)->staticEval;
ss->staticEval = bestValue =
to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos, ss);
to_corrected_static_eval(unadjustedStaticEval, correctionValue);
}

// Stand pat. Return immediately if static value is at least beta
Expand Down

0 comments on commit 397b2f3

Please sign in to comment.