Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refator TT: return whole entry object as ref instead of a tuple of values #632

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions src/Lynx.Dev/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1084,21 +1084,23 @@ static void TesSize(int size)
//transpositionTable.RecordHash(position, depth: 3, maxDepth: 5, move: 1234, eval: +5, nodeType: NodeType.Alpha);
//var entry = transpositionTable.ProbeHash(position, maxDepth: 5, depth: 3, alpha: 1, beta: 2);

transpositionTable.RecordHash(mask, position, depth: 5, ply: 3, eval: +19, nodeType: NodeType.Alpha, move: 1234);
var entry = transpositionTable.ProbeHash(mask, position, depth: 5, ply: 3, alpha: 20, beta: 30);
Console.WriteLine(entry); // Expected 20
TranspositionTableElement ttEntry = default;

transpositionTable.RecordHash(mask, position, depth: 5, ply: 3, eval: +21, nodeType: NodeType.Alpha, move: 1234);
entry = transpositionTable.ProbeHash(mask, position, depth: 5, ply: 3, alpha: 20, beta: 30);
Console.WriteLine(entry); // Expected 12_345_678
transpositionTable.Save(mask, position, depth: 5, ply: 3, eval: +19, nodeType: NodeType.Alpha, move: 1234);
var eval = transpositionTable.Read(mask, position, depth: 5, ply: 3, alpha: 20, beta: 30, ref ttEntry);
Console.WriteLine(eval); // Expected 20

transpositionTable.RecordHash(mask, position, depth: 5, ply: 3, eval: +29, nodeType: NodeType.Beta, move: 1234);
entry = transpositionTable.ProbeHash(mask, position, depth: 5, ply: 3, alpha: 20, beta: 30);
Console.WriteLine(entry); // Expected 12_345_678
transpositionTable.Save(mask, position, depth: 5, ply: 3, eval: +21, nodeType: NodeType.Alpha, move: 1234);
eval = transpositionTable.Read(mask, position, depth: 5, ply: 3, alpha: 20, beta: 30, ref ttEntry);
Console.WriteLine(eval); // Expected 12_345_678

transpositionTable.RecordHash(mask, position, depth: 5, ply: 3, eval: +31, nodeType: NodeType.Beta, move: 1234);
entry = transpositionTable.ProbeHash(mask, position, depth: 5, ply: 3, alpha: 20, beta: 30);
Console.WriteLine(entry); // Expected 30
transpositionTable.Save(mask, position, depth: 5, ply: 3, eval: +29, nodeType: NodeType.Beta, move: 1234);
eval = transpositionTable.Read(mask, position, depth: 5, ply: 3, alpha: 20, beta: 30, ref ttEntry);
Console.WriteLine(eval); // Expected 12_345_678

transpositionTable.Save(mask, position, depth: 5, ply: 3, eval: +31, nodeType: NodeType.Beta, move: 1234);
eval = transpositionTable.Read(mask, position, depth: 5, ply: 3, alpha: 20, beta: 30, ref ttEntry);
Console.WriteLine(eval); // Expected 30
}

static void UnmakeMove()
Expand Down
23 changes: 13 additions & 10 deletions src/Lynx/Model/TranspositionTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,31 +105,32 @@ public static void ClearTranspositionTable(this TranspositionTable transposition
/// <param name="ply">Ply</param>
/// <param name="alpha"></param>
/// <param name="beta"></param>
/// <param name="entry"></param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static (int Evaluation, ShortMove BestMove, NodeType NodeType) ProbeHash(this TranspositionTable tt, int ttMask, Position position, int depth, int ply, int alpha, int beta)
public static int Read(this TranspositionTable tt, int ttMask, Position position, int depth, int ply, int alpha, int beta, ref TranspositionTableElement entry)
{
if (!Configuration.EngineSettings.TranspositionTableEnabled)
{
return (EvaluationConstants.NoHashEntry, default, default);
return EvaluationConstants.NoHashEntry;
}

ref var entry = ref tt[position.UniqueIdentifier & ttMask];
ref TranspositionTableElement localEntry = ref tt[position.UniqueIdentifier & ttMask];

if ((position.UniqueIdentifier >> 48) != entry.Key)
if ((position.UniqueIdentifier >> 48) != localEntry.Key)
{
return (EvaluationConstants.NoHashEntry, default, default);
return EvaluationConstants.NoHashEntry;
}

var eval = EvaluationConstants.NoHashEntry;

if (entry.Depth >= depth)
if (localEntry.Depth >= depth)
{
// We want to translate the checkmate position relative to the saved node to our root position from which we're searching
// If the recorded score is a checkmate in 3 and we are at depth 5, we want to read checkmate in 8
var score = RecalculateMateScores(entry.Score, ply);
var score = RecalculateMateScores(localEntry.Score, ply);

eval = entry.Type switch
eval = localEntry.Type switch
{
NodeType.Exact => score,
NodeType.Alpha when score <= alpha => alpha,
Expand All @@ -138,7 +139,9 @@ public static (int Evaluation, ShortMove BestMove, NodeType NodeType) ProbeHash(
};
}

return (eval, entry.Move, entry.Type);
entry = localEntry;

return eval;
}

/// <summary>
Expand All @@ -153,7 +156,7 @@ public static (int Evaluation, ShortMove BestMove, NodeType NodeType) ProbeHash(
/// <param name="nodeType"></param>
/// <param name="move"></param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void RecordHash(this TranspositionTable tt, int ttMask, Position position, int depth, int ply, int eval, NodeType nodeType, Move? move = null)
public static void Save(this TranspositionTable tt, int ttMask, Position position, int depth, int ply, int eval, NodeType nodeType, Move? move = null)
{
if (!Configuration.EngineSettings.TranspositionTableEnabled)
{
Expand Down
36 changes: 18 additions & 18 deletions src/Lynx/Search/NegaMax.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ private int NegaMax(int depth, int ply, int alpha, int beta, bool parentWasNullM

bool isRoot = ply == 0;
bool pvNode = beta - alpha > 1;
ShortMove ttBestMove = default;
NodeType ttElementType = default;
TranspositionTableElement ttEntry = default;
int ttEvaluation = default;

if (!isRoot)
{
(ttEvaluation, ttBestMove, ttElementType) = _tt.ProbeHash(_ttMask, position, depth, ply, alpha, beta);
ttEvaluation = _tt.Read(_ttMask, position, depth, ply, alpha, beta, ref ttEntry);
if (ttEvaluation != EvaluationConstants.NoHashEntry)
{
return ttEvaluation;
Expand All @@ -55,7 +54,7 @@ private int NegaMax(int depth, int ply, int alpha, int beta, bool parentWasNullM
// so the search will be potentially expensive.
// Therefore, we search with reduced depth for now, expecting to record a TT move
// which we'll be able to use later for the full depth search
if (ttElementType == default && depth >= Configuration.EngineSettings.IIR_MinDepth)
if (ttEntry.Type == default && depth >= Configuration.EngineSettings.IIR_MinDepth)
{
--depth;
}
Expand All @@ -78,7 +77,7 @@ private int NegaMax(int depth, int ply, int alpha, int beta, bool parentWasNullM
}

var finalPositionEvaluation = Position.EvaluateFinalPosition(ply, isInCheck);
_tt.RecordHash(_ttMask, position, depth, ply, finalPositionEvaluation, NodeType.Exact);
_tt.Save(_ttMask, position, depth, ply, finalPositionEvaluation, NodeType.Exact);
return finalPositionEvaluation;
}

Expand All @@ -91,7 +90,7 @@ private int NegaMax(int depth, int ply, int alpha, int beta, bool parentWasNullM
&& staticEval >= beta
&& !parentWasNullMove
&& phase > 2 // Zugzwang risk reduction: pieces other than pawn presents
&& (ttElementType != NodeType.Alpha || ttEvaluation >= beta)) // TT suggests NMP will fail: entry must not be a fail-low entry with a score below beta - Stormphrax and Ethereal
&& (ttEntry.Type != NodeType.Alpha || ttEvaluation >= beta)) // TT suggests NMP will fail: entry must not be a fail-low entry with a score below beta - Stormphrax and Ethereal
{
var nmpReduction = Configuration.EngineSettings.NMP_BaseDepthReduction + ((depth + 1) / 3); // Clarity

Expand Down Expand Up @@ -165,7 +164,7 @@ private int NegaMax(int depth, int ply, int alpha, int beta, bool parentWasNullM
_isFollowingPV = false;
for (int i = 0; i < pseudoLegalMoves.Length; ++i)
{
scores[i] = ScoreMove(pseudoLegalMoves[i], ply, isNotQSearch: true, ttBestMove);
scores[i] = ScoreMove(pseudoLegalMoves[i], ply, isNotQSearch: true, ttEntry.Move);

if (pseudoLegalMoves[i] == _pVTable[depth])
{
Expand All @@ -178,7 +177,7 @@ private int NegaMax(int depth, int ply, int alpha, int beta, bool parentWasNullM
{
for (int i = 0; i < pseudoLegalMoves.Length; ++i)
{
scores[i] = ScoreMove(pseudoLegalMoves[i], ply, isNotQSearch: true, ttBestMove);
scores[i] = ScoreMove(pseudoLegalMoves[i], ply, isNotQSearch: true, ttEntry.Move);
}
}

Expand Down Expand Up @@ -378,7 +377,7 @@ private int NegaMax(int depth, int ply, int alpha, int beta, bool parentWasNullM
}
}

_tt.RecordHash(_ttMask, position, depth, ply, beta, NodeType.Beta, bestMove);
_tt.Save(_ttMask, position, depth, ply, beta, NodeType.Beta, bestMove);

return beta; // TODO return evaluation?
}
Expand All @@ -401,11 +400,11 @@ private int NegaMax(int depth, int ply, int alpha, int beta, bool parentWasNullM
{
var eval = Position.EvaluateFinalPosition(ply, isInCheck);

_tt.RecordHash(_ttMask, position, depth, ply, eval, NodeType.Exact);
_tt.Save(_ttMask, position, depth, ply, eval, NodeType.Exact);
return eval;
}

_tt.RecordHash(_ttMask, position, depth, ply, alpha, nodeType, bestMove);
_tt.Save(_ttMask, position, depth, ply, alpha, nodeType, bestMove);

// Node fails low
return alpha;
Expand Down Expand Up @@ -441,12 +440,13 @@ public int QuiescenceSearch(int ply, int alpha, int beta)
var nextPvIndex = PVTable.Indexes[ply + 1];
_pVTable[pvIndex] = _defaultMove; // Nulling the first value before any returns

var ttProbeResult = _tt.ProbeHash(_ttMask, position, 0, ply, alpha, beta);
if (ttProbeResult.Evaluation != EvaluationConstants.NoHashEntry)
TranspositionTableElement ttEntry = default;
var ttProbeEvaluation = _tt.Read(_ttMask, position, 0, ply, alpha, beta, ref ttEntry);
if (ttProbeEvaluation != EvaluationConstants.NoHashEntry)
{
return ttProbeResult.Evaluation;
return ttProbeEvaluation;
}
ShortMove ttBestMove = ttProbeResult.BestMove;
ShortMove ttBestMove = ttEntry.Move;

_maxDepthReached[ply] = ply;

Expand Down Expand Up @@ -551,7 +551,7 @@ public int QuiescenceSearch(int ply, int alpha, int beta)
{
PrintMessage($"Pruning: {move} is enough to discard this line");

_tt.RecordHash(_ttMask, position, 0, ply, beta, NodeType.Beta, bestMove);
_tt.Save(_ttMask, position, 0, ply, beta, NodeType.Beta, bestMove);

return evaluation; // The refutation doesn't matter, since it'll be pruned
}
Expand All @@ -573,12 +573,12 @@ public int QuiescenceSearch(int ply, int alpha, int beta)
&& !MoveGenerator.CanGenerateAtLeastAValidMove(position))
{
var finalEval = Position.EvaluateFinalPosition(ply, position.IsInCheck());
_tt.RecordHash(_ttMask, position, 0, ply, finalEval, NodeType.Exact);
_tt.Save(_ttMask, position, 0, ply, finalEval, NodeType.Exact);

return finalEval;
}

_tt.RecordHash(_ttMask, position, 0, ply, alpha, nodeType, bestMove);
_tt.Save(_ttMask, position, 0, ply, alpha, nodeType, bestMove);

return alpha;
}
Expand Down
15 changes: 9 additions & 6 deletions tests/Lynx.Test/Model/TranspositionTableTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ public void RecordHash_ProbeHash(int recordedEval, NodeType recordNodeType, int
var (mask, length) = TranspositionTableExtensions.CalculateLength(Configuration.EngineSettings.TranspositionTableSize);
var transpositionTable = new TranspositionTableElement[length];

transpositionTable.RecordHash(mask, position, depth: 5, ply: 3, eval: recordedEval, nodeType: recordNodeType, move: 1234);
transpositionTable.Save(mask, position, depth: 5, ply: 3, eval: recordedEval, nodeType: recordNodeType, move: 1234);

Assert.AreEqual(expectedProbeEval, transpositionTable.ProbeHash(mask, position, depth: 5, ply: 3, alpha: probeAlpha, beta: probeBeta).Evaluation);
TranspositionTableElement ttEntry = default;
Assert.AreEqual(expectedProbeEval, transpositionTable.Read(mask, position, depth: 5, ply: 3, alpha: probeAlpha, beta: probeBeta, ref ttEntry));
}

[TestCase(CheckMateBaseEvaluation - 8 * CheckmateDepthFactor)]
Expand All @@ -98,9 +99,10 @@ public void RecordHash_ProbeHash_CheckmateSameDepth(int recordedEval)
var (mask, length) = TranspositionTableExtensions.CalculateLength(Configuration.EngineSettings.TranspositionTableSize);
var transpositionTable = new TranspositionTableElement[length];

transpositionTable.RecordHash(mask, position, depth: 10, ply: sharedDepth, eval: recordedEval, nodeType: NodeType.Exact, move: 1234);
transpositionTable.Save(mask, position, depth: 10, ply: sharedDepth, eval: recordedEval, nodeType: NodeType.Exact, move: 1234);

Assert.AreEqual(recordedEval, transpositionTable.ProbeHash(mask, position, depth: 7, ply: sharedDepth, alpha: 50, beta: 100).Evaluation);
TranspositionTableElement ttEntry = default;
Assert.AreEqual(recordedEval, transpositionTable.Read(mask, position, depth: 7, ply: sharedDepth, alpha: 50, beta: 100, ref ttEntry));
}

[TestCase(CheckMateBaseEvaluation - 8 * CheckmateDepthFactor, 5, 4, CheckMateBaseEvaluation - 7 * CheckmateDepthFactor)]
Expand All @@ -113,8 +115,9 @@ public void RecordHash_ProbeHash_CheckmateDifferentDepth(int recordedEval, int r
var (mask, length) = TranspositionTableExtensions.CalculateLength(Configuration.EngineSettings.TranspositionTableSize);
var transpositionTable = new TranspositionTableElement[length];

transpositionTable.RecordHash(mask, position, depth: 10, ply: recordedDeph, eval: recordedEval, nodeType: NodeType.Exact, move: 1234);
transpositionTable.Save(mask, position, depth: 10, ply: recordedDeph, eval: recordedEval, nodeType: NodeType.Exact, move: 1234);

Assert.AreEqual(expectedProbeEval, transpositionTable.ProbeHash(mask, position, depth: 7, ply: probeDepth, alpha: 50, beta: 100).Evaluation);
TranspositionTableElement ttEntry = default;
Assert.AreEqual(expectedProbeEval, transpositionTable.Read(mask, position, depth: 7, ply: probeDepth, alpha: 50, beta: 100, ref ttEntry));
}
}
Loading