Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 Use TT recalculated scores everywhere and move TT cutoffs to NegaMax and QSearch methods #1310

Merged
merged 6 commits into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/Lynx.Dev/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1083,19 +1083,19 @@ static void TesSize(int size)
//var entry = transpositionTable.ProbeHash(position, maxDepth: 5, depth: 3, alpha: 1, beta: 2);

transpositionTable.RecordHash(position, position.StaticEvaluation().Score, depth: 5, ply: 3, score: +19, nodeType: NodeType.Alpha, move: 1234);
var entry = transpositionTable.ProbeHash(position, depth: 5, ply: 3, alpha: 20, beta: 30);
var entry = transpositionTable.ProbeHash(position, ply: 3);
Console.WriteLine(entry); // Expected 20

transpositionTable.RecordHash(position, position.StaticEvaluation().Score, depth: 5, ply: 3, score: +21, nodeType: NodeType.Alpha, move: 1234);
entry = transpositionTable.ProbeHash(position, depth: 5, ply: 3, alpha: 20, beta: 30);
entry = transpositionTable.ProbeHash(position, ply: 3);
Console.WriteLine(entry); // Expected 12_345_678

transpositionTable.RecordHash(position, position.StaticEvaluation().Score, depth: 5, ply: 3, score: +29, nodeType: NodeType.Beta, move: 1234);
entry = transpositionTable.ProbeHash(position, depth: 5, ply: 3, alpha: 20, beta: 30);
entry = transpositionTable.ProbeHash(position, ply: 3);
Console.WriteLine(entry); // Expected 12_345_678

transpositionTable.RecordHash(position, position.StaticEvaluation().Score, depth: 5, ply: 3, score: +31, nodeType: NodeType.Beta, move: 1234);
entry = transpositionTable.ProbeHash(position, depth: 5, ply: 3, alpha: 20, beta: 30);
entry = transpositionTable.ProbeHash(position, ply: 3);
Console.WriteLine(entry); // Expected 30
}

Expand Down
24 changes: 5 additions & 19 deletions src/Lynx/Model/TranspositionTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public void PrefetchTTEntry(Position position)
/// </summary>
/// <param name="ply">Ply</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public (int Score, ShortMove BestMove, NodeType NodeType, int RawScore, int StaticEval) ProbeHash(Position position, int depth, int ply, int alpha, int beta)
public (int Score, ShortMove BestMove, NodeType NodeType, int StaticEval, int Depth) ProbeHash(Position position, int ply)
{
var ttIndex = CalculateTTIndex(position.UniqueIdentifier);
var entry = _tt[ttIndex];
Expand All @@ -69,25 +69,11 @@ public void PrefetchTTEntry(Position position)
return (EvaluationConstants.NoHashEntry, default, default, default, default);
}

var type = entry.Type;
var rawScore = entry.Score;
var score = EvaluationConstants.NoHashEntry;
// We want to translate the checkmate position relative to the saved node to our root position from which we're searching
// If the recorded score is a checkmate in 3 and we are at depth 5, we want to read checkmate in 8
var recalculatedScore = RecalculateMateScores(entry.Score, ply);

if (entry.Depth >= depth)
{
var recalculatedScore = RecalculateMateScores(rawScore, ply);

if (type == NodeType.Exact
|| (type == NodeType.Alpha && recalculatedScore <= alpha)
|| (type == NodeType.Beta && recalculatedScore >= beta))
{
// We want to translate the checkmate position relative to the saved node to our root position from which we're searching
// If the recorded score is a checkmate in 3 and we are at depth 5, we want to read checkmate in 8
score = recalculatedScore;
}
}

return (score, entry.Move, entry.Type, rawScore, entry.StaticEval);
return (recalculatedScore, entry.Move, entry.Type, entry.StaticEval, entry.Depth);
}

/// <summary>
Expand Down
9 changes: 9 additions & 0 deletions src/Lynx/Model/TranspositionTableElement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,17 @@ public enum NodeType : byte
#pragma warning restore S4022 // Enumerations should have "Int32" storage
{
Unknown, // Making it 0 instead of -1 because of default struct initialization

Exact,

/// <summary>
/// UpperBound
/// </summary>
Alpha,

/// <summary>
/// LowerBound
/// </summary>
Beta
}

Expand Down
36 changes: 26 additions & 10 deletions src/Lynx/Search/NegaMax.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,22 @@ private int NegaMax(int depth, int ply, int alpha, int beta, bool cutnode, Cance
ShortMove ttBestMove = default;
NodeType ttElementType = default;
int ttScore = default;
int ttRawScore = default;
int ttStaticEval = int.MinValue;
int ttDepth = default;

Debug.Assert(!pvNode || !cutnode);

if (!isRoot)
{
(ttScore, ttBestMove, ttElementType, ttRawScore, ttStaticEval) = _tt.ProbeHash(position, depth, ply, alpha, beta);
(ttScore, ttBestMove, ttElementType, ttStaticEval, ttDepth) = _tt.ProbeHash(position, ply);

// TT cutoffs
if (!pvNode && ttScore != EvaluationConstants.NoHashEntry)
if (!pvNode
&& ttScore != EvaluationConstants.NoHashEntry
&& ttDepth >= depth
&& (ttElementType == NodeType.Exact
|| (ttElementType == NodeType.Alpha && ttScore <= alpha)
|| (ttElementType == NodeType.Beta && ttScore >= beta)))
{
return ttScore;
}
Expand Down Expand Up @@ -124,9 +129,9 @@ private int NegaMax(int depth, int ply, int alpha, int beta, bool cutnode, Cance
// If the score is outside what the current bounds are, but it did match flag and depth,
// then we can trust that this score is more accurate than the current static evaluation,
// and we can update our static evaluation for better accuracy in pruning
if (ttElementType != default && ttElementType != (ttRawScore > staticEval ? NodeType.Alpha : NodeType.Beta))
if (ttElementType != default && ttElementType != (ttScore > staticEval ? NodeType.Alpha : NodeType.Beta))
{
staticEval = ttRawScore;
staticEval = ttScore;
}

bool isNotGettingCheckmated = staticEval > EvaluationConstants.NegativeCheckmateDetectionLimit;
Expand Down Expand Up @@ -190,7 +195,7 @@ private int NegaMax(int depth, int ply, int alpha, int beta, bool cutnode, Cance
&& staticEvalBetaDiff >= 0
&& !parentWasNullMove
&& phase > 2 // Zugzwang risk reduction: pieces other than pawn presents
&& (ttElementType != NodeType.Alpha || ttRawScore >= beta)) // TT suggests NMP will fail: entry must not be a fail-low entry with a score below beta - Stormphrax and Ethereal
&& (ttElementType != NodeType.Alpha || ttScore >= beta)) // TT suggests NMP will fail: entry must not be a fail-low entry with a score below beta - Stormphrax and Ethereal
{
var nmpReduction = Configuration.EngineSettings.NMP_BaseDepthReduction
+ ((depth + Configuration.EngineSettings.NMP_DepthIncrement) / Configuration.EngineSettings.NMP_DepthDivisor) // Clarity
Expand Down Expand Up @@ -531,16 +536,27 @@ public int QuiescenceSearch(int ply, int alpha, int beta, CancellationToken canc
var nextPvIndex = PVTable.Indexes[ply + 1];
_pVTable[pvIndex] = _defaultMove; // Nulling the first value before any returns

var ttProbeResult = _tt.ProbeHash(position, 0, ply, alpha, beta);
if (ttProbeResult.Score != EvaluationConstants.NoHashEntry)
var ttProbeResult = _tt.ProbeHash(position, ply);
var ttScore = ttProbeResult.Score;
var ttNodeType = ttProbeResult.NodeType;
var ttHit = ttNodeType != NodeType.Unknown;

// QS TT cutoff
Debug.Assert(ttProbeResult.Depth >= 0, "Assertion failed", "We would need to add it as a TT cutoff condition");

if (ttHit
&& (ttNodeType == NodeType.Exact
|| (ttNodeType == NodeType.Alpha && ttScore <= alpha)
|| (ttNodeType == NodeType.Beta && ttScore >= beta)))
{
return ttProbeResult.Score;
return ttScore;
}

ShortMove ttBestMove = ttProbeResult.BestMove;

_maxDepthReached[ply] = ply;

var staticEval = ttProbeResult.NodeType != NodeType.Unknown
var staticEval = ttHit // TODO check if static eval
? ttProbeResult.StaticEval
: position.StaticEvaluation(Game.HalfMovesWithoutCaptureOrPawnMove).Score;

Expand Down
14 changes: 6 additions & 8 deletions tests/Lynx.Test/Model/TranspositionTableTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ public void RecalculateMateScores(int evaluation, int depth, int expectedEvaluat
Assert.AreEqual(expectedEvaluation, TranspositionTable.RecalculateMateScores(evaluation, depth));
}

[TestCase(+19, NodeType.Alpha, +20, +30, +19)]
[TestCase(+21, NodeType.Alpha, +20, +30, NoHashEntry)]
[TestCase(+29, NodeType.Beta, +20, +30, NoHashEntry)]
[TestCase(+31, NodeType.Beta, +20, +30, +31)]
public void RecordHash_ProbeHash(int recordedEval, NodeType recordNodeType, int probeAlpha, int probeBeta, int expectedProbeEval)
[TestCase(+19, NodeType.Alpha, +19)]
[TestCase(+31, NodeType.Beta, +31)]
public void RecordHash_ProbeHash(int recordedEval, NodeType recordNodeType, int expectedProbeEval)
{
var position = new Position(Constants.InitialPositionFEN);
var transpositionTable = new TranspositionTable();
Expand All @@ -37,7 +35,7 @@ public void RecordHash_ProbeHash(int recordedEval, NodeType recordNodeType, int

transpositionTable.RecordHash(position, staticEval, depth: 5, ply: 3, score: recordedEval, nodeType: recordNodeType, move: 1234);

var ttEntry = transpositionTable.ProbeHash(position, depth: 5, ply: 3, alpha: probeAlpha, beta: probeBeta);
var ttEntry = transpositionTable.ProbeHash(position, ply: 3);
Assert.AreEqual(expectedProbeEval, ttEntry.Score);
Assert.AreEqual(staticEval, ttEntry.StaticEval);
}
Expand All @@ -52,7 +50,7 @@ public void RecordHash_ProbeHash_CheckmateSameDepth(int recordedEval)

transpositionTable.RecordHash(position, recordedEval, depth: 10, ply: sharedDepth, score: recordedEval, nodeType: NodeType.Exact, move: 1234);

var ttEntry = transpositionTable.ProbeHash(position, depth: 7, ply: sharedDepth, alpha: 50, beta: 100);
var ttEntry = transpositionTable.ProbeHash(position, ply: sharedDepth);
Assert.AreEqual(recordedEval, ttEntry.Score);
Assert.AreEqual(recordedEval, ttEntry.StaticEval);
}
Expand All @@ -68,7 +66,7 @@ public void RecordHash_ProbeHash_CheckmateDifferentDepth(int recordedEval, int r

transpositionTable.RecordHash(position, recordedEval, depth: 10, ply: recordedDeph, score: recordedEval, nodeType: NodeType.Exact, move: 1234);

var ttEntry = transpositionTable.ProbeHash(position, depth: 7, ply: probeDepth, alpha: 50, beta: 100);
var ttEntry = transpositionTable.ProbeHash(position, ply: probeDepth);
Assert.AreEqual(expectedProbeEval, ttEntry.Score);
Assert.AreEqual(recordedEval, ttEntry.StaticEval);
}
Expand Down
Loading