Skip to content

Commit

Permalink
initial-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Belliveau committed Jan 24, 2025
1 parent 7b0ef8d commit 570829a
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 14 deletions.
4 changes: 3 additions & 1 deletion src/engine_classic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ void EngineClassic::PopulateOptions(OptionsParser* options) {
options->HideAllOptions();
options->UnhideOption(kThreadsOptionId);
options->UnhideOption(SharedBackendParams::kWeightsId);
options->UnhideOption(SharedBackendParams::kOpponentWeightsId);
options->UnhideOption(classic::SearchParams::kContemptId);
options->UnhideOption(classic::SearchParams::kMultiPvId);
}
Expand Down Expand Up @@ -156,6 +157,7 @@ void EngineClassic::UpdateFromUciOptions() {
NetworkFactory::BackendConfiguration(options_);
if (network_configuration_ != network_configuration) {
network_ = NetworkFactory::LoadNetwork(options_);
opponent_network_ = NetworkFactory::LoadOpponentNetwork(options_);
network_configuration_ = network_configuration;
}

Expand Down Expand Up @@ -385,7 +387,7 @@ void EngineClassic::Go(const GoParams& params) {

auto stopper = time_manager_->GetStopper(params, *tree_.get());
search_ = std::make_unique<classic::Search>(
*tree_, network_.get(), std::move(responder),
*tree_, network_.get(), opponent_network_.get(), std::move(responder),
StringsToMovelist(params.searchmoves, tree_->HeadPosition().GetBoard()),
*move_start_time_, std::move(stopper), params.infinite, params.ponder,
options_, &cache_, syzygy_tb_.get());
Expand Down
1 change: 1 addition & 0 deletions src/engine_classic.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class EngineClassic : public EngineControllerBase {
std::unique_ptr<classic::NodeTree> tree_;
std::unique_ptr<SyzygyTablebase> syzygy_tb_;
std::unique_ptr<Network> network_;
std::unique_ptr<Network> opponent_network_;
NNCache cache_;

// Store current TB and network settings to track when they change so that
Expand Down
22 changes: 22 additions & 0 deletions src/neural/factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ std::unique_ptr<Network> NetworkFactory::Create(
NetworkFactory::BackendConfiguration::BackendConfiguration(
const OptionsDict& options)
: weights_path(options.Get<std::string>(SharedBackendParams::kWeightsId)),
opponent_weights_path(
options.Get<std::string>(SharedBackendParams::kOpponentWeightsId)),
backend(options.Get<std::string>(SharedBackendParams::kBackendId)),
backend_options(
options.Get<std::string>(SharedBackendParams::kBackendOptionsId)) {}
Expand Down Expand Up @@ -103,4 +105,24 @@ std::unique_ptr<Network> NetworkFactory::LoadNetwork(
return ptr;
}

// Loads the network used to model the opponent, from the
// "OpponentWeightsFile" option.  Mirrors LoadNetwork(), but reads
// kOpponentWeightsId instead of kWeightsId.
//
// Unlike LoadNetwork(), an empty path leaves `weights` disengaged and the
// backend is created without a weights file (assumes the selected backend
// accepts an empty std::optional<WeightsFile> -- TODO confirm; otherwise an
// empty path should probably return nullptr instead).
//
// @param options  Options dict holding the weights path, backend name and
//                 backend-specific option string.
// @return         Newly created Network; ownership passes to the caller.
std::unique_ptr<Network> NetworkFactory::LoadOpponentNetwork(
    const OptionsDict& options) {
  const std::string net_path =
      options.Get<std::string>(SharedBackendParams::kOpponentWeightsId);
  const std::string backend =
      options.Get<std::string>(SharedBackendParams::kBackendId);
  const std::string backend_options =
      options.Get<std::string>(SharedBackendParams::kBackendOptionsId);

  // Only touch the filesystem when a path (or <autodiscover>) was given.
  std::optional<WeightsFile> weights;
  if (!net_path.empty()) weights = LoadWeights(net_path);

  // Backend options are parsed into a child dict of `options`.
  OptionsDict network_options(&options);
  network_options.AddSubdictFromString(backend_options);

  auto ptr = NetworkFactory::Get()->Create(backend, std::move(weights),
                                           network_options);
  // Verify the backend consumed every option it was given (catches typos in
  // the backend-options string).
  network_options.CheckAllOptionsRead(backend);
  return ptr;
}

} // namespace lczero
6 changes: 4 additions & 2 deletions src/neural/factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,22 @@ class NetworkFactory {
// Helper function to load the network from the options. Returns nullptr
// if no network options changed since the previous call.
static std::unique_ptr<Network> LoadNetwork(const OptionsDict& options);
static std::unique_ptr<Network> LoadOpponentNetwork(const OptionsDict& options);

// Snapshot of every UCI option that affects which backend/network pair is
// loaded.  Engines compare snapshots (operator!=) to decide whether the
// networks must be reloaded after an option change.
struct BackendConfiguration {
  BackendConfiguration() = default;
  BackendConfiguration(const OptionsDict& options);
  std::string weights_path;           // Main "WeightsFile" option value.
  std::string opponent_weights_path;  // "OpponentWeightsFile" option value.
  std::string backend;                // Backend name (e.g. cuda, blas).
  std::string backend_options;        // Backend-specific option string.
  bool operator==(const BackendConfiguration& other) const;
  bool operator!=(const BackendConfiguration& other) const {
    return !operator==(other);
  }
  // Lexicographic order over all four fields so configurations can serve as
  // keys in ordered containers.  Must compare the same fields as operator==.
  // (The pre-change comparison that omitted opponent_weights_path was stale
  // diff residue and is removed here.)
  bool operator<(const BackendConfiguration& other) const {
    return std::tie(weights_path, opponent_weights_path, backend,
                    backend_options) <
           std::tie(other.weights_path, other.opponent_weights_path,
                    other.backend, other.backend_options);
  }
};

Expand Down
8 changes: 8 additions & 0 deletions src/neural/shared_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ const OptionId SharedBackendParams::kWeightsId{
"makes it search in ./ and ./weights/ subdirectories for the latest (by "
"file date) file which looks like weights.",
'w'};
// UCI option identifying the opponent model's weights file.  The help text
// below is specific to the opponent network; the previous text was
// copy-pasted from kWeightsId and incorrectly described the main weights.
const OptionId SharedBackendParams::kOpponentWeightsId{
    "opponent_weights", "OpponentWeightsFile",
    "Path from which to load the opponent model's network weights.\nSetting "
    "it to <autodiscover> makes it search in ./ and ./weights/ "
    "subdirectories for the latest (by file date) file which looks like "
    "weights.",
    'O'};
const OptionId SharedBackendParams::kBackendId{
"backend", "Backend", "Neural network computational backend to use.", 'b'};
const OptionId SharedBackendParams::kBackendOptionsId{
Expand All @@ -68,6 +74,8 @@ void SharedBackendParams::Populate(OptionsParser* options) {
#else
options->Add<StringOption>(SharedBackendParams::kWeightsId) = kAutoDiscover;
#endif
options->Add<StringOption>(SharedBackendParams::kOpponentWeightsId) =
kAutoDiscover;
const auto backends = NetworkFactory::Get()->GetBackendsList();
options->Add<ChoiceOption>(SharedBackendParams::kBackendId, backends) =
backends.empty() ? "<none>" : backends[0];
Expand Down
1 change: 1 addition & 0 deletions src/neural/shared_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ struct SharedBackendParams {
static const OptionId kPolicySoftmaxTemp;
static const OptionId kHistoryFill;
static const OptionId kWeightsId;
static const OptionId kOpponentWeightsId;
static const OptionId kBackendId;
static const OptionId kBackendOptionsId;
static const OptionId kNNCacheSizeId;
Expand Down
103 changes: 95 additions & 8 deletions src/search/classic/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
#include <sstream>
#include <thread>

#include "search/classic/node.h"
#include "neural/cache.h"
#include "neural/encoder.h"
#include "search/classic/node.h"
#include "utils/fastmath.h"
#include "utils/random.h"
#include "utils/spinhelper.h"
Expand Down Expand Up @@ -151,6 +151,7 @@ class MEvaluator {
} // namespace

Search::Search(const NodeTree& tree, Network* network,
Network* opponent_network,
std::unique_ptr<UciResponder> uci_responder,
const MoveList& searchmoves,
std::chrono::steady_clock::time_point start_time,
Expand All @@ -164,6 +165,7 @@ Search::Search(const NodeTree& tree, Network* network,
syzygy_tb_(syzygy_tb),
played_history_(tree.GetPositionHistory()),
network_(network),
opponent_network_(opponent_network),
params_(options),
searchmoves_(searchmoves),
start_time_(start_time),
Expand Down Expand Up @@ -1640,17 +1642,102 @@ void SearchWorker::PickNodesToExtendTask(

int max_limit = std::numeric_limits<int>::max();

// 1) Determine root color
bool root_is_black = search_->played_history_.IsBlackToMove();

current_path.push_back(-1);
while (current_path.size() > 0) {
// Need to do n visits, where n is either collision_limit, or comes from
// visits_to_perform for the current path.
int cur_limit = collision_limit;
if (current_path.size() > 1) {
cur_limit =
(*visits_to_perform.back())[current_path[current_path.size() - 2]];
}

// 2) Check if the node color is “opponent” of the root.
// (Odd/even ply from root decides side to move.)
bool node_is_black = (((int)current_path.size() + base_depth) % 2)
? !root_is_black
: root_is_black;
bool is_opponent_node = (node_is_black != root_is_black);

if (is_opponent_node) {
// -------------------------------
// *** Opponent logic: do forced 1-move with opponent_network_. ***
// -------------------------------
// Evaluate position with opponent_network_:
int transform;
auto planes = EncodePositionForNN(
search_->opponent_network_->GetCapabilities().input_format,
workspace->history, 8, params_.GetHistoryFill(), &transform);

auto opp_computation = search_->opponent_network_->NewComputation();
opp_computation->AddInput(std::move(planes));
opp_computation->ComputeBlocking();

// pick best policy among node->Edges():
float best_pol = -99999.0f;
int best_idx = -1;
int idx = 0;
for (auto edge_it : node->Edges()) {
float p = opp_computation->GetPVal(
0, edge_it.GetMove().as_nn_index(transform));
if (p > best_pol) {
best_pol = p;
best_idx = idx;
}
idx++;
}
if (best_idx < 0) {
// fallback if no moves
node = node->GetParent();
current_path.pop_back();
continue;
}
// Grab the forced child (the best move) by iterating node->Edges():
Node::Iterator forced_child = node->Edges();
for (int i = 0; i < best_idx; i++) {
++forced_child;
}

// Spawn (or reuse) the child node for that edge:
Node* child_node = forced_child.GetOrSpawnNode(node);

// Optionally do direct visits and collisions:
// We use 'cur_limit' as the number of visits we intended for this node.
int forced_visits = cur_limit;

// If brand-new or unvisited, we can do one immediate "visit":
if (child_node->TryStartScoreUpdate()) {
forced_visits -= 1;
// Make sure we queue it for a real NN inference or backup:
receiver->push_back(NodeToProcess::Visit(
child_node, (uint16_t)(current_path.size() + base_depth)));
}

// If we still have more visits left, mark them as collisions:
if (forced_visits > 0) {
receiver->push_back(NodeToProcess::Collision(
child_node, (uint16_t)(current_path.size() + base_depth),
forced_visits));
}

// 2) "Move the search in that direction":
// Tell Lc0's search loop to descend one ply deeper into child_node.
// That means we do the standard "found child" approach:
current_path.back() = best_idx; // We pick that edge index
current_path.push_back(-1); // Push a new level
node = child_node; // Descend
is_root_node = false;

// Finally, 'continue;' so we skip the usual multi-move logic for this
// node
continue;
}

// First prepare visits_to_perform.
if (current_path.back() == -1) {
// Need to do n visits, where n is either collision_limit, or comes from
// visits_to_perform for the current path.
int cur_limit = collision_limit;
if (current_path.size() > 1) {
cur_limit =
(*visits_to_perform.back())[current_path[current_path.size() - 2]];
}
// First check if node is terminal or not-expanded. If either than create
// a collision of appropriate size and pop current_path.
if (node->GetN() == 0 || node->IsTerminal()) {
Expand Down
4 changes: 3 additions & 1 deletion src/search/classic/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ namespace classic {

class Search {
public:
Search(const NodeTree& tree, Network* network,
Search(const NodeTree& tree, Network* network, Network* opponent_network,
std::unique_ptr<UciResponder> uci_responder,
const MoveList& searchmoves,
std::chrono::steady_clock::time_point start_time,
Expand Down Expand Up @@ -169,6 +169,7 @@ class Search {
const PositionHistory& played_history_;

Network* const network_;
Network* const opponent_network_;
const SearchParams params_;
const MoveList searchmoves_;
const std::chrono::steady_clock::time_point start_time_;
Expand Down Expand Up @@ -471,6 +472,7 @@ class SearchWorker {
// List of nodes to process.
std::vector<NodeToProcess> minibatch_;
std::unique_ptr<CachingComputation> computation_;
std::unique_ptr<CachingComputation> opponent_computation_;
int task_workers_;
int target_minibatch_size_;
int max_out_of_order_;
Expand Down
2 changes: 1 addition & 1 deletion src/selfplay/game.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ void SelfPlayGame::Play(int white_threads, int black_threads, bool training,
}

search_ = std::make_unique<classic::Search>(
*tree_[idx], options_[idx].network, std::move(responder),
*tree_[idx], options_[idx].network, options_[idx].network, std::move(responder),
/* searchmoves */ MoveList(), std::chrono::steady_clock::now(),
std::move(stoppers), /* infinite */ false, /* ponder */ false,
*options_[idx].uci_options, options_[idx].cache, syzygy_tb);
Expand Down
2 changes: 1 addition & 1 deletion src/tools/benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ void Benchmark::Run() {

const auto start = std::chrono::steady_clock::now();
auto search = std::make_unique<classic::Search>(
tree, network.get(),
tree, network.get(), network.get(),
std::make_unique<CallbackUciResponder>(
std::bind(&Benchmark::OnBestMove, this, std::placeholders::_1),
std::bind(&Benchmark::OnInfo, this, std::placeholders::_1)),
Expand Down

0 comments on commit 570829a

Please sign in to comment.