new backend interface prototype #2092

Draft · wants to merge 18 commits into master
47 changes: 0 additions & 47 deletions src/mcts/node.cc
@@ -127,53 +127,6 @@ Move Edge::GetMove(bool as_opponent) const {
return m;
}

// Policy priors (P) are stored in a compressed 16-bit format.
//
// Source values are 32-bit floats:
// * bit 31 is sign (zero means positive)
// * bit 30 is sign of exponent (zero means nonpositive)
// * bits 29..23 are value bits of exponent
// * bits 22..0 are significand bits (plus a "virtual" always-on bit: s ∈ [1,2))
// The number is then sign * 2^exponent * significand (ignoring denormals and
// other special values).
// See https://www.h-schmidt.net/FloatConverter/IEEE754.html for details.
//
// In compressed 16-bit value we store bits 27..12:
// * bit 31 is always off as values are always >= 0
// * bit 30 is always off as values are always < 2
// * bits 29..28 are only off for values < 4.6566e-10, assume they are always on
// * bits 11..0 are for higher precision, they are dropped leaving only 11 bits
// of precision
//
// When converting to the compressed format, bit 11 is added in order to round
// to nearest rather than truncate.
//
// Out of 65536 possible values, 2047 are outside of the [0,1] interval (they
// are in the interval (1,2)). This is fine because the values in [0,1] are
// skewed towards 0, which is also exactly how the components of the policy
// tend to behave (since they add up to 1).

// If the two assumed-on exponent bits (3<<28) are in fact off, the input is
// rounded up to the smallest value with them on. We accomplish this by
// subtracting the two bits from the input and checking for a negative result
// (the subtraction works despite crossing from exponent to significand). This
// is combined with the round-to-nearest addition (1<<11) into one op.
void Edge::SetP(float p) {
assert(0.0f <= p && p <= 1.0f);
constexpr int32_t roundings = (1 << 11) - (3 << 28);
int32_t tmp;
std::memcpy(&tmp, &p, sizeof(float));
tmp += roundings;
p_ = (tmp < 0) ? 0 : static_cast<uint16_t>(tmp >> 12);
}

float Edge::GetP() const {
// Reshift into place and set the assumed-set exponent bits.
uint32_t tmp = (static_cast<uint32_t>(p_) << 12) | (3 << 28);
float ret;
std::memcpy(&ret, &tmp, sizeof(uint32_t));
return ret;
}

std::string Edge::DebugString() const {
std::ostringstream oss;
oss << "Move: " << move_.as_string() << " p_: " << p_ << " GetP: " << GetP();
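The deleted block above is the only place in this excerpt that documents the compressed policy format, which this PR moves behind the new pfloat16 type in utils/pfloat16.h. For reference, the same encode/decode round trip can be exercised standalone; the sketch below reuses the removed SetP/GetP arithmetic verbatim, with CompressP/DecompressP being illustrative names rather than anything defined in this PR.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Encode a policy prior in [0,1] into 16 bits by keeping float bits 27..12,
// rounding to nearest (same arithmetic as the removed Edge::SetP).
inline uint16_t CompressP(float p) {
  assert(0.0f <= p && p <= 1.0f);
  constexpr int32_t kRoundings = (1 << 11) - (3 << 28);
  int32_t tmp;
  std::memcpy(&tmp, &p, sizeof(float));
  tmp += kRoundings;  // Rounds to nearest and tests the two assumed-on exponent bits.
  return (tmp < 0) ? 0 : static_cast<uint16_t>(tmp >> 12);
}

// Decode by shifting the stored bits back into place and re-setting the two
// exponent bits that are assumed on (same as the removed Edge::GetP).
inline float DecompressP(uint16_t compressed) {
  uint32_t tmp = (static_cast<uint32_t>(compressed) << 12) | (3u << 28);
  float ret;
  std::memcpy(&ret, &tmp, sizeof(uint32_t));
  return ret;
}
```

The round trip keeps roughly 11 bits of relative precision, so DecompressP(CompressP(0.3f)) differs from 0.3 by less than about 1e-4.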
9 changes: 5 additions & 4 deletions src/mcts/node.h
@@ -36,10 +36,10 @@
#include "chess/board.h"
#include "chess/callbacks.h"
#include "chess/position.h"
#include "neural/cache.h"
#include "neural/encoder.h"
#include "proto/net.pb.h"
#include "utils/mutex.h"
#include "utils/pfloat16.h"

namespace lczero {

@@ -92,8 +92,9 @@ class Edge {

// Returns or sets value of Move policy prior returned from the neural net
// (but can be changed by adding Dirichlet noise). Must be in [0,1].
float GetP() const;
void SetP(float val);
float GetP() const { return p_; }
void SetP(float val) { p_ = val; }
void SetP(pfloat16 p) { p_ = p; }

// Debug information about the edge.
std::string DebugString() const;
@@ -106,7 +107,7 @@

// Probability that this move will be made, from the policy head of the neural
// network; compressed to a 16 bit format (5 bits exp, 11 bits significand).
uint16_t p_ = 0;
pfloat16 p_;
friend class Node;
};

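Note that the inline GetP/SetP above rely on pfloat16 converting implicitly to and from float; the actual utils/pfloat16.h is not part of this diff excerpt, but a wrapper along the following lines would be sufficient for these call sites (a sketch under that assumption, not the PR's real implementation).

```cpp
#include <cstdint>
#include <cstring>

// 16-bit compressed policy prior with implicit float conversions, reusing the
// bit layout documented in the comment removed from node.cc.
class pfloat16 {
 public:
  pfloat16() = default;
  pfloat16(float p) {  // Implicit on purpose so `p_ = val;` compiles in SetP(float).
    constexpr int32_t kRoundings = (1 << 11) - (3 << 28);
    int32_t tmp;
    std::memcpy(&tmp, &p, sizeof(float));
    tmp += kRoundings;
    bits_ = (tmp < 0) ? 0 : static_cast<uint16_t>(tmp >> 12);
  }
  operator float() const {  // Lets `return p_;` work in GetP().
    uint32_t tmp = (static_cast<uint32_t>(bits_) << 12) | (3u << 28);
    float ret;
    std::memcpy(&ret, &tmp, sizeof(uint32_t));
    return ret;
  }

 private:
  uint16_t bits_ = 0;
};
```

With such a type, the new SetP(pfloat16) overload can store an already-compressed cached value without a decode/encode round trip, which appears to be the point of adding it.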
139 changes: 43 additions & 96 deletions src/mcts/search.cc
@@ -1258,8 +1258,10 @@ void SearchWorker::ExecuteOneIteration() {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
void SearchWorker::InitializeIteration(
std::unique_ptr<NetworkComputation> computation) {
computation_ = std::make_unique<CachingComputation>(std::move(computation),
search_->cache_);
computation_ = std::make_unique<CachingComputation>(
std::move(computation), search_->network_->GetCapabilities().input_format,
params_.GetHistoryFill(), params_.GetPolicySoftmaxTemp(),
params_.GetCacheHistoryLength() + 1, search_->cache_);
computation_->Reserve(target_minibatch_size_);
minibatch_.clear();
minibatch_.reserve(2 * target_minibatch_size_);
@@ -1418,16 +1420,8 @@ void SearchWorker::GatherMinibatch() {
// If there was no OOO, there can still be collisions.
// There are no OOO though.
// Also terminals when OOO is disabled.
if (!minibatch_[i].nn_queried) continue;
if (minibatch_[i].is_cache_hit) {
// Since minibatch_[i] holds cache lock, this is guaranteed to succeed.
computation_->AddInputByHash(minibatch_[i].hash,
std::move(minibatch_[i].lock));
} else {
computation_->AddInput(minibatch_[i].hash,
std::move(minibatch_[i].input_planes),
std::move(minibatch_[i].probabilities_to_cache));
}
if (!minibatch_[i].nn_queried || minibatch_[i].is_cache_hit) continue;
computation_->AddInput(minibatch_[i].history, minibatch_[i].moves);
}

// Check for stop at the end so we have at least one node.
@@ -1470,30 +1464,22 @@ void SearchWorker::ProcessPickedTask(int start_idx, int end_idx,
// of the game), it means that we already visited this node before.
if (picked_node.IsExtendable()) {
// Node was never visited, extend it.
ExtendNode(node, picked_node.depth, picked_node.moves_to_visit, &history);
// Initialize position sequence with pre-move position.
history.Trim(search_->played_history_.GetLength());
history.Reserve(search_->played_history_.GetLength() +
picked_node.moves_to_visit.size());
for (size_t i = 0; i < picked_node.moves_to_visit.size(); i++) {
history.Append(picked_node.moves_to_visit[i]);
}

picked_node.moves = history.Last().GetBoard().GenerateLegalMoves();

ExtendNode(node, picked_node.depth, history, picked_node.moves);
if (!node->IsTerminal()) {
picked_node.nn_queried = true;
const auto hash = history.HashLast(params_.GetCacheHistoryLength() + 1);
picked_node.hash = hash;
picked_node.lock = NNCacheLock(search_->cache_, hash);
picked_node.is_cache_hit = picked_node.lock;
if (!picked_node.is_cache_hit) {
int transform;
picked_node.input_planes = EncodePositionForNN(
search_->network_->GetCapabilities().input_format, history, 8,
params_.GetHistoryFill(), &transform);
picked_node.probability_transform = transform;

std::vector<uint16_t>& moves = picked_node.probabilities_to_cache;
// Legal moves are known, use them.
moves.reserve(node->GetNumEdges());
for (const auto& edge : node->Edges()) {
moves.emplace_back(edge.GetMove().as_nn_index(transform));
}
} else {
picked_node.probability_transform = TransformForPosition(
search_->network_->GetCapabilities().input_format, history);
}
picked_node.is_cache_hit = computation_->CacheLookup(
history, picked_node.moves, &picked_node.entry);
if (!picked_node.is_cache_hit) picked_node.history = history;
}
}
if (params_.GetOutOfOrderEval() && picked_node.CanEvalOutOfOrder()) {
@@ -1939,19 +1925,11 @@ void SearchWorker::PickNodesToExtendTask(
}

void SearchWorker::ExtendNode(Node* node, int depth,
const std::vector<Move>& moves_to_node,
PositionHistory* history) {
// Initialize position sequence with pre-move position.
history->Trim(search_->played_history_.GetLength());
for (size_t i = 0; i < moves_to_node.size(); i++) {
history->Append(moves_to_node[i]);
}

const PositionHistory& history,
const MoveList& legal_moves) {
// We don't need the mutex because other threads will see that N=0 and
// N-in-flight=1 and will not touch this node.
const auto& board = history->Last().GetBoard();
auto legal_moves = board.GenerateLegalMoves();

const auto& board = history.Last().GetBoard();
// Check whether it's a draw/lose by position. Importantly, we must check
// these before doing the by-rule checks below.
if (legal_moves.empty()) {
@@ -1972,21 +1950,21 @@ void SearchWorker::ExtendNode(Node* node, int depth,
return;
}

if (history->Last().GetRule50Ply() >= 100) {
if (history.Last().GetRule50Ply() >= 100) {
node->MakeTerminal(GameResult::DRAW);
return;
}

const auto repetitions = history->Last().GetRepetitions();
const auto repetitions = history.Last().GetRepetitions();
// Mark two-fold repetitions as draws according to settings.
// Depth starts with 1 at root, so number of plies in PV is depth - 1.
if (repetitions >= 2) {
node->MakeTerminal(GameResult::DRAW);
return;
} else if (repetitions == 1 && depth - 1 >= 4 &&
params_.GetTwoFoldDraws() &&
depth - 1 >= history->Last().GetPliesSincePrevRepetition()) {
const auto cycle_length = history->Last().GetPliesSincePrevRepetition();
depth - 1 >= history.Last().GetPliesSincePrevRepetition()) {
const auto cycle_length = history.Last().GetPliesSincePrevRepetition();
// use plies since first repetition as moves left; exact if forced draw.
node->MakeTerminal(GameResult::DRAW, (float)cycle_length,
Node::Terminal::TwoFold);
Expand All @@ -1996,12 +1974,12 @@ void SearchWorker::ExtendNode(Node* node, int depth,
// Neither by-position or by-rule termination, but maybe it's a TB position.
if (search_->syzygy_tb_ && !search_->root_is_in_dtz_ &&
board.castlings().no_legal_castle() &&
history->Last().GetRule50Ply() == 0 &&
history.Last().GetRule50Ply() == 0 &&
(board.ours() | board.theirs()).count() <=
search_->syzygy_tb_->max_cardinality()) {
ProbeState state;
const WDLScore wdl =
search_->syzygy_tb_->probe_wdl(history->Last(), &state);
search_->syzygy_tb_->probe_wdl(history.Last(), &state);
// Only the FAIL state means the WDL is wrong; probe_wdl may produce a correct
// result with a state other than OK.
if (state != FAIL) {
@@ -2037,35 +2015,23 @@ void SearchWorker::ExtendNode(Node* node, int depth,

// Returns whether node was already in cache.
bool SearchWorker::AddNodeToComputation(Node* node) {
const auto hash = history_.HashLast(params_.GetCacheHistoryLength() + 1);
if (search_->cache_->ContainsKey(hash)) {
if (computation_->CacheLookup(history_)) {
return true;
}
int transform;
auto planes =
EncodePositionForNN(search_->network_->GetCapabilities().input_format,
history_, 8, params_.GetHistoryFill(), &transform);

std::vector<uint16_t> moves;
MoveList moves;

if (node && node->HasChildren()) {
// Legal moves are known, use them.
moves.reserve(node->GetNumEdges());
for (const auto& edge : node->Edges()) {
moves.emplace_back(edge.GetMove().as_nn_index(transform));
moves.emplace_back(edge.GetMove());
}
} else {
// Cache pseudolegal moves. A bit of a waste, but faster.
const auto& pseudolegal_moves =
history_.Last().GetBoard().GeneratePseudolegalMoves();
moves.reserve(pseudolegal_moves.size());
for (auto iter = pseudolegal_moves.begin(), end = pseudolegal_moves.end();
iter != end; ++iter) {
moves.emplace_back(iter->as_nn_index(transform));
}
// Cache legal moves.
moves = history_.Last().GetBoard().GenerateLegalMoves();
}

computation_->AddInput(hash, std::move(planes), std::move(moves));
computation_->AddInput(history_, moves);
return false;
}

@@ -2197,6 +2163,10 @@ void SearchWorker::FetchMinibatchResults() {
// Populate NN/cached results, or terminal results, into nodes.
int idx_in_computation = 0;
for (auto& node_to_process : minibatch_) {
if (node_to_process.is_cache_hit) {
FetchSingleNodeResult(&node_to_process, node_to_process, 0);
continue;
}
FetchSingleNodeResult(&node_to_process, *computation_, idx_in_computation);
if (node_to_process.nn_queried) ++idx_in_computation;
}
@@ -2236,34 +2206,11 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process,
node_to_process->v = v;
node_to_process->d = d;
node_to_process->m = computation.GetMVal(idx_in_computation);
// ...and secondly, the policy data.
// Calculate maximum first.
float max_p = -std::numeric_limits<float>::infinity();
// Intermediate array to store values when processing policy.
// There are never more than 256 valid legal moves in any legal position.
std::array<float, 256> intermediate;
int counter = 0;
for (auto& edge : node->Edges()) {
float p = computation.GetPVal(
idx_in_computation,
edge.GetMove().as_nn_index(node_to_process->probability_transform));
intermediate[counter++] = p;
max_p = std::max(max_p, p);
}
float total = 0.0;
for (int i = 0; i < counter; i++) {
// Perform softmax and take into account policy softmax temperature T.
// Note that we want to calculate (exp(p-max_p))^(1/T) = exp((p-max_p)/T).
float p =
FastExp((intermediate[i] - max_p) / params_.GetPolicySoftmaxTemp());
intermediate[i] = p;
total += p;
}
counter = 0;
// Normalize P values to add up to 1.0.
const float scale = total > 0.0f ? 1.0f / total : 1.0f;
// ...and secondly, the policy data. The cache returns compressed values after
// softmax.
int idx = 0;
for (auto& edge : node->Edges()) {
edge.edge()->SetP(intermediate[counter++] * scale);
edge.edge()->SetP(computation.GetPVal(idx_in_computation, idx++));
}
// Add Dirichlet noise if enabled and at root.
if (params_.GetNoiseEpsilon() && node == search_->root_node_) {
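The softmax pass deleted from FetchSingleNodeResult above does not disappear: per the new comment and the extra CachingComputation constructor arguments (input format, history fill, policy softmax temperature), the same transform presumably now runs before policy values are compressed into the cache entry. A standalone sketch of that transform, with SoftmaxPolicy being an illustrative name and std::exp standing in for the project's FastExp:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// In-place max-subtracted softmax with policy temperature T:
//   p_i <- exp((p_i - max_p) / T) / sum_j exp((p_j - max_p) / T)
// Matches the arithmetic of the loop removed from FetchSingleNodeResult.
inline void SoftmaxPolicy(std::vector<float>* priors, float temperature) {
  if (priors->empty()) return;
  const float max_p = *std::max_element(priors->begin(), priors->end());
  float total = 0.0f;
  for (float& p : *priors) {
    // Note that (exp(p - max_p))^(1/T) == exp((p - max_p) / T).
    p = std::exp((p - max_p) / temperature);
    total += p;
  }
  // Normalize the priors to add up to 1.0 (leave them if everything underflowed).
  const float scale = total > 0.0f ? 1.0f / total : 1.0f;
  for (float& p : *priors) p *= scale;
}
```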
34 changes: 10 additions & 24 deletions src/mcts/search.h
@@ -44,6 +44,7 @@
#include "syzygy/syzygy.h"
#include "utils/logging.h"
#include "utils/mutex.h"
#include "utils/pfloat16.h"

namespace lczero {

@@ -331,18 +332,16 @@ class SearchWorker {
bool nn_queried = false;
bool is_cache_hit = false;
bool is_collision = false;
int probability_transform = 0;

// Details only populated in the multigather path.

// Only populated for visits,
std::vector<Move> moves_to_visit;

// Details that are filled in as we go.
uint64_t hash;
NNCacheLock lock;
std::vector<uint16_t> probabilities_to_cache;
InputPlanes input_planes;
CachedNNRequest entry;
MoveList moves;
PositionHistory history;
mutable int last_idx = 0;
bool ooo_completed = false;

@@ -361,26 +360,13 @@
// Methods to allow NodeToProcess to conform as a 'Computation'. Only safe
// to call if is_cache_hit is true in the multigather path.

float GetQVal(int) const { return lock->q; }

float GetDVal(int) const { return lock->d; }
float GetQVal(int) const { return entry.q; }

float GetMVal(int) const { return lock->m; }
float GetDVal(int) const { return entry.d; }

float GetPVal(int, int move_id) const {
const auto& moves = lock->p;
float GetMVal(int) const { return entry.m; }

int total_count = 0;
while (total_count < moves.size()) {
// Optimization: usually moves are stored in the same order as queried.
const auto& move = moves[last_idx++];
if (last_idx == moves.size()) last_idx = 0;
if (move.first == move_id) return move.second;
++total_count;
}
assert(false); // Move not found.
return 0;
}
pfloat16 GetPVal(int, int move_ct) const { return entry.p[move_ct]; }

private:
NodeToProcess(Node* node, uint16_t depth, bool is_collision, int multivisit,
@@ -455,8 +441,8 @@ class SearchWorker {
void EnsureNodeTwoFoldCorrectForDepth(Node* node, int depth);
void ProcessPickedTask(int batch_start, int batch_end,
TaskWorkspace* workspace);
void ExtendNode(Node* node, int depth, const std::vector<Move>& moves_to_add,
PositionHistory* history);
void ExtendNode(Node* node, int depth, const PositionHistory& history,
const MoveList& legal_moves);
template <typename Computation>
void FetchSingleNodeResult(NodeToProcess* node_to_process,
const Computation& computation,
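The GetQVal/GetDVal/GetMVal/GetPVal accessors above are what allow a cache-hit NodeToProcess to be passed to the templated FetchSingleNodeResult in place of the real network computation (see the is_cache_hit branch added to FetchMinibatchResults). A toy illustration of that duck-typed pattern, with every name below invented for the example:

```cpp
#include <iostream>

// Anything exposing GetQVal(int) can act as a "Computation".
struct FromCacheEntry {
  float q;
  float GetQVal(int /*idx*/) const { return q; }  // Batch index is ignored.
};

struct FromNetworkBatch {
  const float* batch_q;
  float GetQVal(int idx) const { return batch_q[idx]; }
};

// One template serves both sources, as the templated FetchSingleNodeResult does.
template <typename Computation>
void ReportQ(const Computation& computation, int idx_in_computation) {
  std::cout << "q = " << computation.GetQVal(idx_in_computation) << "\n";
}

int main() {
  FromCacheEntry cached{0.25f};
  const float batch[] = {0.1f, -0.4f};
  FromNetworkBatch net{batch};
  ReportQ(cached, 0);  // Cache hit: index is always 0.
  ReportQ(net, 1);     // Real batch: index selects the entry.
  return 0;
}
```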