Skip to content

Commit

Permalink
- update classification
Browse files Browse the repository at this point in the history
- change UCT function (which works)
- add executable
  • Loading branch information
franck-ledoux authored and nicolaslg committed Sep 5, 2024
1 parent 86d3311 commit 88c4398
Show file tree
Hide file tree
Showing 20 changed files with 62,070 additions and 255 deletions.
1 change: 0 additions & 1 deletion .github/workflows/continuous-ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ jobs:
run: >
pwd;
ls;
. /spack/share/spack/setup-env.sh;
spack load py-pytest;
cmake /__w/gmds/gmds -DCMAKE_BUILD_TYPE=${{ matrix.config }}
-DWITH_CODE_COVERAGE:BOOL=ON
Expand Down
6 changes: 6 additions & 0 deletions mctsblock/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#==============================================================================
set(MCTS_DATA_DIR "" CACHE INTERNAL "Data directory where are input files")
#==============================================================================
# LIBRARY DEFINTION (SOURCE FILES)
#==============================================================================
# Nommer tout en GMDS_MODULE_NAME, GMDS_SRC, ... dans les composants
Expand Down Expand Up @@ -35,6 +37,10 @@ set(GMDS_SRC
spam/src/MCTSTree.cpp
)

configure_file(
config.h.in
config.h
@ONLY)
#==============================================================================
add_library(${GMDS_LIB} ${GMDS_INC} ${GMDS_SRC})
#==============================================================================
Expand Down
1 change: 1 addition & 0 deletions mctsblock/config.h.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#cmakedefine MCTS_DATA_DIR "@MCTS_DATA_DIR@"
9 changes: 9 additions & 0 deletions mctsblock/inc/gmds/mctsblock/Blocking.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,11 @@ class LIB_GMDS_MCTSBLOCK_API Blocking
* @param[out] AFaceIds ids of boundary faces
*/
void extract_boundary(std::set<TCellID> &ANodeIds, std::set<TCellID> &AEdgeIds, std::set<TCellID> &AFaceIds);

/**@brief This method reset all the nodes, edges and faces classification data to geom dim -1, geom id 4
*/
void reset_classification();

/** Return the info for the node of id @p ANodeId
* @param[in] ANodeId topological node id
* @return a tuple where the first parameter is the geom_dim, the second its geom_id, and the third is location
Expand Down Expand Up @@ -731,6 +736,10 @@ class LIB_GMDS_MCTSBLOCK_API Blocking
*/
void init_from_mesh(Mesh &ACellMesh);


/**@brief intiialize the block structure from the bounding box of the geom model
*/
void init_from_bounding_box();
/**@brief Convert the block structure into a gmds cellular mesh. The provided
* mesh @p ACellMesh must have the following characteristics:
* - DIM3, N, E, F, R, R2N, F2N, E2N
Expand Down
5 changes: 5 additions & 0 deletions mctsblock/inc/gmds/mctsblock/Graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ class LIB_GMDS_MCTSBLOCK_API Graph
void computeDijkstra(TCellID ASrcNode);
//Be careful, input ids are in the global input space numbering, not the local one
void setWeight(const TCellID AN1, const TCellID AN2, const double AW);
/**
*
* @return for each node of the graph D, you get the path from the source node S to D.
* The first item of the vector is D, and the last one S.
*/
std::map<TCellID , std::vector<TCellID> > getShortestPath();
std::map<TCellID , double > getShortestPathWeights();
private:
Expand Down
5 changes: 4 additions & 1 deletion mctsblock/spam/inc/mcts/MCTSAgent.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class MCTSAgent {
DRAW,
LOST
};

/**@brief Constructor
* @param ARewardFunction reward function to evaluate a state
* @param ASelectFunction function to pick a child node during the selection process
Expand Down Expand Up @@ -71,14 +72,16 @@ class MCTSAgent {
* @return the "best" solution
*/
std::shared_ptr<IState> get_best_solution();
std::shared_ptr<IState> get_best_winning_solution();


/**@brief Returns the best child of the root. We consider
* here the best solution as being the most visited node at each level.
*
* @return the "best" root child
*/
std::shared_ptr<IState> get_most_visited_child();
std::shared_ptr<IState> get_most_visited_child();
std::shared_ptr<IState> get_most_winning_child();
/**@brief provides the number of iterations done by the algorithm
*/
int get_nb_iterations() const {return m_nb_iterations;}
Expand Down
3 changes: 2 additions & 1 deletion mctsblock/spam/inc/mcts/MCTSTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class MCTSTree{
/**@brief Returns the most visited child of the current root node
* @return a node, or subtree
*/
MCTSTree* get_most_visited_child() const;
MCTSTree* get_most_visited_child() const;
MCTSTree* get_most_winning_child() const;

/**@brief Create a new child that will be obtained by applying an untried action
* @return a child node obtained from applying an untried action on the current state
Expand Down
34 changes: 32 additions & 2 deletions mctsblock/spam/src/MCTSAgent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ std::shared_ptr<IState> MCTSAgent::get_best_solution() {
return node->get_state();
}

/*---------------------------------------------------------------------------*/
std::shared_ptr<IState> MCTSAgent::get_best_winning_solution()
{ const MCTSTree* node = m_tree;
while (!node->is_terminal() && node->has_children()){
node=node->get_most_winning_child();
}
return node->get_state();
}

/*---------------------------------------------------------------------------*/
std::shared_ptr<IState> MCTSAgent::get_most_visited_child() {
const MCTSTree* node = m_tree;
Expand All @@ -49,6 +58,16 @@ std::shared_ptr<IState> MCTSAgent::get_most_visited_child() {
}
return node->get_state();
}

/*---------------------------------------------------------------------------*/
std::shared_ptr<IState> MCTSAgent::get_most_winning_child() {
const MCTSTree* node = m_tree;

if (!node->is_terminal() && node->has_children()){
node=node->get_most_winning_child();
}
return node->get_state();
}
/*---------------------------------------------------------------------------*/
MCTSTree* MCTSAgent::expand(MCTSTree* ANode) {
if(!ANode->is_fully_expanded() && !ANode->is_terminal())
Expand All @@ -64,14 +83,25 @@ MCTSTree* MCTSAgent::expand(MCTSTree* ANode) {
std::pair<double,MCTSAgent::GAME_RESULT> MCTSAgent::simulate(MCTSTree* ANode) {
auto state =ANode->get_state();

if(!ANode->is_terminal()) {
for (int d = 0; d < m_simulation_depth; d++) {
//we first check that we are not a winner!!
if(state->win())
std::make_pair(m_reward_function->evaluate(state),WIN);

bool found_win=false;
bool found_lost =false;
if(!state->win() && !ANode->is_terminal()) {
//TODO checker cette boucle.
for (int d = 0; d < m_simulation_depth && !found_win && !found_lost; d++) {
if (!state->is_terminal()) {
auto a = get_random_action(state);
if(a== nullptr)
exit(55);

state = a->apply_on(state);
if (state->win())
found_win=true;
else if (state->lost())
found_lost=true;
}
}
}
Expand Down
31 changes: 16 additions & 15 deletions mctsblock/spam/src/MCTSSelectionFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@ MCTSTree* UCBSelectionFunction::select(MCTSTree *ANode) const {
// iterate all immediate children and find best UTC score
auto children = ANode->get_children();
auto num_children = children.size();
for(auto i = 0; i < num_children; i++) {
auto child = children[i];
float uct_exploitation = (float)child->number_win / (child->number_visits + FLT_EPSILON);
float uct_exploration = sqrt(log((float)ANode->number_visits + 1) / (child->number_visits + FLT_EPSILON) );
float uct_score = uct_exploitation + m_c * uct_exploration;
for(auto i = 0; i < num_children; i++) {
auto child = children[i];
auto np = (double)ANode->number_visits + 1;
auto ni = (double)child->number_visits + FLT_EPSILON;
auto vi = (double)child->cumulative_reward;
auto utc_score= vi/ni + m_c * sqrt(2*log(np) / ni);

if(uct_score > best_utc_score) {
best_utc_score = uct_score;
best_node = child;
}
}
if(utc_score > best_utc_score) {
best_utc_score = utc_score;
best_node = child;
}
}
if(best_node== nullptr)
throw std::runtime_error("Error when getting the best child of a node");

Expand All @@ -53,13 +54,13 @@ MCTSTree* SPUCTSelectionFunction::select(MCTSTree *ANode) const {

for(auto i = 0; i < num_children; i++) {
auto child = children[i];
auto tN = (double)ANode->number_visits + 1;
auto tNi = (double)child->number_visits + FLT_EPSILON;
auto w = (double)child->number_win;
auto utc= w/tNi + m_c * sqrt(log(tN) / tNi);
auto np = (double)ANode->number_visits + 1;
auto ni = (double)child->number_visits + FLT_EPSILON;
auto vi = (double)child->cumulative_reward;
auto utc= vi/ni + m_c * sqrt(2*log(np) / ni);

auto sum_x2 = child->sq_cumulative_reward;
auto utc_single = sqrt((sum_x2 - tNi* pow(w/tNi,2) +m_d) / tNi);
auto utc_single = sqrt((sum_x2 - ni* pow(vi/ni,2) +m_d) / ni);
auto utc_score = utc + utc_single;

if(utc_score > best_score) {
Expand Down
38 changes: 38 additions & 0 deletions mctsblock/spam/src/MCTSTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ MCTSTree* MCTSTree::get_most_visited_child() const {

// iterate all children and find most visited
for(auto c: m_children) {

std::cout<<"Child "<<c<<" from action "<<c->get_action()->get_description()<<" (visits, win, lost, draw, cumul. reward): "<<c->number_visits<<", "<<c->number_win<<", "<<c->number_lost<<", "<<c->number_draw<<", "<<c->cumulative_reward<<std::endl;

if(c->number_visits > most_visits) {
Expand All @@ -77,6 +78,43 @@ MCTSTree* MCTSTree::get_most_visited_child() const {

return best_node;
}

/*---------------------------------------------------------------------------*/
MCTSTree* MCTSTree::get_most_winning_child() const {
int most_wins = -1;
MCTSTree* best_node = nullptr;

// iterate all children and find most visited
for(auto c: m_children) {
std::cout<<"Child "<<c<<" from action "<<c->get_action()->get_description()<<" (visits, win, lost, draw, cumul. reward): "<<c->number_visits<<", "<<c->number_win<<", "<<c->number_lost<<", "<<c->number_draw<<", "<<c->cumulative_reward<<std::endl;

if(c->number_win > most_wins) {
most_wins = c->number_win;
best_node = c;
}
}
if (most_wins==0){
//look for draws
most_wins=-1;
// iterate all children and find most visited
for(auto c: m_children) {
if(c->number_draw > most_wins) {
most_wins = c->number_draw;
best_node = c;
}
}
}
if(best_node== nullptr)
throw std::runtime_error("Error when visiting children");
std::cout<<"--> Best "<<best_node<<" (visits, win, lost, draw, cumul. reward): "
<<best_node->number_visits<<", "
<<best_node->number_win<<", "
<<best_node->number_lost<<", "
<<best_node->number_draw<<", "
<<best_node->cumulative_reward<<std::endl;

return best_node;
}
/*---------------------------------------------------------------------------*/
MCTSTree* MCTSTree::get_child(const int AI) const {
return m_children[AI];
Expand Down
Loading

0 comments on commit 88c4398

Please sign in to comment.