diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9a56d84..82549cd 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -18,7 +18,7 @@ jobs: # This workflow contains a single job called "build" build: # The type of runner that the job will run on - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 # Steps represent a sequence of tasks that will be executed as part of the job steps: @@ -32,13 +32,16 @@ jobs: # Runs a set of commands using the runners shell - name: install dependencies run: | - sudo apt install -y qt5-default qtbase5-dev qt5-qmake build-essential wget + sudo apt install -y qtbase5-dev qt5-qmake qttools5-dev build-essential wget cmake # Runs a set of commands using the runners shell - name: make appimage run: | - ls - ./build-AppImage.sh + ls + cmake --version + cmake -DCMAKE_BUILD_TYPE=Release -S . -B build + # Cap parallel jobs at the core count (an unbounded make job count can exhaust runner memory) + make -C build -j"$(nproc)" - uses: actions/upload-artifact@v2 with: diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..84c048a --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/build/ diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..72b1d7c --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,134 @@ +cmake_minimum_required(VERSION 3.20) + +option(USE_CUDA "Compile the src/*.cu sources and define USE_CUDA" OFF) +option(QT_GUI "Build the TexasSolverGui Qt executable" ON) +option(BUILD_API "Build the shared api library from src/api.cpp" ON) + +if(USE_CUDA) + project(TexasSolver LANGUAGES CXX CUDA) +else() + project(TexasSolver LANGUAGES CXX) +endif() + +set(CMAKE_CXX_STANDARD 20) +# set(CMAKE_CXX_STANDARD_REQUIRED ON) +if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") +endif() +message("CMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}") + +set(CMAKE_INCLUDE_CURRENT_DIR ON) +include_directories(include) + +find_package(OpenMP REQUIRED) +message("OpenMP_CXX_FLAGS=${OpenMP_CXX_FLAGS}") + +file(GLOB_RECURSE SRC src/*.cpp) +file(GLOB_RECURSE EXC_SRC src/*format.cpp) +file(GLOB GUI_SRC *.cpp src/ui/*.cpp src/runtime/qsolverjob.cpp) +file(GLOB API_SRC 
src/api.cpp) +file(GLOB EXE_SRC src/console.cpp) +list(REMOVE_ITEM SRC ${EXC_SRC} ${GUI_SRC} ${EXE_SRC} ${API_SRC}) +# message("SRC=${SRC}") +# message("EXC_SRC=${EXC_SRC}") +# message("GUI_SRC=${GUI_SRC}") +# message("API_SRC=${API_SRC}") +# message("EXE_SRC=${EXE_SRC}") + +if(USE_CUDA) + add_definitions(-DUSE_CUDA) + file(GLOB_RECURSE CUDA_SRC src/*.cu) + message("CUDA_SRC=${CUDA_SRC}") + + set(CMAKE_CUDA_STANDARD 20) + # set(CMAKE_CUDA_STANDARD_REQUIRED ON) + message("CMAKE_VERSION=${CMAKE_VERSION}") + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.24) + # CMake >= 3.24 accepts the special values all, all-major and native; compare the full version so a future major release is not misread. + set(CMAKE_CUDA_ARCHITECTURES all-major) + # set(CMAKE_CUDA_ARCHITECTURES native) + else() + set(CMAKE_CUDA_ARCHITECTURES OFF) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -arch=all-major") + endif() + message("CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}") + + message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}") + if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + set(CMAKE_CUDA_FLAGS "-g -G ${CMAKE_CUDA_FLAGS}") + endif() + + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler ${OpenMP_CXX_FLAGS}") + message("CMAKE_CUDA_FLAGS=${CMAKE_CUDA_FLAGS}") +endif() + +set(BASE_LIB TexasSolver) +add_library(${BASE_LIB} ${SRC} ${CUDA_SRC}) +target_link_libraries(${BASE_LIB} PUBLIC OpenMP::OpenMP_CXX) +if(USE_CUDA) +# set_target_properties(${BASE_LIB} PROPERTIES CUDA_SEPARABLE_COMPILATION ON) +endif() + +if(BUILD_API) + set(API_TARGET api) + add_library(${API_TARGET} SHARED ${API_SRC}) + target_link_libraries(${API_TARGET} PUBLIC ${BASE_LIB}) +endif() + +set(EXE console_solver) +add_executable(${EXE} ${EXE_SRC}) +target_link_libraries(${EXE} PRIVATE ${BASE_LIB}) +if(MSVC) + target_link_options(${EXE} PUBLIC "/NODEFAULTLIB:LIBCMT") +endif() + +if(QT_GUI) + file(GLOB FORMS *.ui) + file(GLOB RESOURCES *.qrc) + file(GLOB TS_FILES *.ts) + file(GLOB QM_FILES *.qm) + # message("FORMS=${FORMS}") + # message("RESOURCES=${RESOURCES}") + # 
message("TS_FILES=${TS_FILES}") + # message("QM_FILES=${QM_FILES}") + + # set(CMAKE_AUTOMOC ON) + set(CMAKE_AUTORCC ON) + set(CMAKE_AUTOUIC ON) + + find_package(QT NAMES Qt6 Qt5 REQUIRED COMPONENTS Widgets) + set(QT_MAJOR Qt${QT_VERSION_MAJOR}) + message("QT_MAJOR=${QT_MAJOR}") + find_package(${QT_MAJOR} REQUIRED COMPONENTS Core Widgets LinguistTools) + + SET(ICON_NAME texassolver_logo) + if(WIN32) + file(GLOB ICON_FILE imgs/${ICON_NAME}.rc) + elseif(APPLE) + set(MACOSX_BUNDLE_ICON_FILE ${ICON_NAME}.icns) + file(GLOB ICON_FILE imgs/${ICON_NAME}.icns) + set_source_files_properties(${ICON_FILE} PROPERTIES MACOSX_PACKAGE_LOCATION "Resources") + endif() + # message("ICON_FILE=${ICON_FILE}") + + # set(CMAKE_AUTOMOC ON) doesn't work + # Q_OBJECT header + file(GLOB HEADERS *.h include/ui/*.h include/runtime/qsolverjob.h) + # message("HEADERS=${HEADERS}") + if(${QT_VERSION_MAJOR} GREATER_EQUAL 6) + qt6_wrap_cpp(GUI_SRC ${HEADERS}) + else() + qt5_wrap_cpp(GUI_SRC ${HEADERS}) + endif() + # message("GUI_SRC=${GUI_SRC}") + + set(GUI TexasSolverGui) + add_executable(${GUI} ${GUI_SRC} ${EXC_SRC} ${RESOURCES} ${FORMS} ${ICON_FILE}) + target_link_libraries(${GUI} PRIVATE ${QT_MAJOR}::Widgets ${QT_MAJOR}::Core ${BASE_LIB}) + set_target_properties(${GUI} PROPERTIES + WIN32_EXECUTABLE ON + MACOSX_BUNDLE ON + ) +endif() diff --git a/TexasSolverGui.pro b/TexasSolverGui.pro index c932ba6..1ca9f44 100644 --- a/TexasSolverGui.pro +++ b/TexasSolverGui.pro @@ -25,6 +25,7 @@ DEFINES += QT_DEPRECATED_WARNINGS TRANSLATIONS = lang_cn.ts\ lang_en.ts +CONFIG += c++2a macx: { QMAKE_CXXFLAGS += -Xpreprocessor -fopenmp -lomp -I/usr/local/include @@ -64,7 +65,8 @@ SOURCES += \ mainwindow.cpp \ src/Deck.cpp \ src/Card.cpp \ - src/console.cpp \ + src/card_format.cpp \ + # src/console.cpp \ src/GameTree.cpp \ src/library.cpp \ src/compairer/Dic5Compairer.cpp \ @@ -85,6 +87,7 @@ SOURCES += \ src/solver/CfrSolver.cpp \ src/solver/PCfrSolver.cpp \ src/solver/Solver.cpp \ + src/solver/slice_cfr.cpp \ 
src/tools/CommandLineTool.cpp \ src/tools/GameTreeBuildingSettings.cpp \ src/tools/lookup8.cpp \ @@ -93,6 +96,7 @@ SOURCES += \ src/tools/Rule.cpp \ src/tools/StreetSetting.cpp \ src/tools/utils.cpp \ + src/tools/logger.cpp \ src/trainable/CfrPlusTrainable.cpp \ src/trainable/DiscountedCfrTrainable.cpp \ src/trainable/DiscountedCfrTrainableHF.cpp \ @@ -129,6 +133,7 @@ HEADERS += \ include/trainable/DiscountedCfrTrainableSF.h \ mainwindow.h \ include/Card.h \ + include/card_format.h \ include/GameTree.h \ include/Deck.h \ include/json.hpp \ @@ -137,6 +142,7 @@ HEADERS += \ include/solver/Solver.h \ include/solver/BestResponse.h \ include/solver/CfrSolver.h \ + include/solver/slice_cfr.h \ include/tools/argparse.hpp \ include/tools/CommandLineTool.h \ include/tools/utils.h \ @@ -164,6 +170,7 @@ HEADERS += \ include/ranges/RiverCombs.h \ include/ranges/RiverRangeManager.h \ include/tools/tinyformat.h \ + include/tools/logger.h \ include/tools/qdebugstream.h \ include/runtime/qsolverjob.h \ qstextedit.h \ diff --git a/benchmark/texassolver.txt b/benchmark/texassolver.txt new file mode 100644 index 0000000..748dec1 --- /dev/null +++ b/benchmark/texassolver.txt @@ -0,0 +1,42 @@ +set_pot 10 +set_effective_stack 95 +set_board Qs,Jh,2h,4d +#set_range_oop AA,KK,QQ,JJ +#set_range_ip QQ:0.5,JJ:0.75 +#set_board Qs,Jh,2h +set_range_oop AA,KK,QQ,JJ,TT,99:0.75,88:0.75,77:0.5,66:0.25,55:0.25,AK,AQs,AQo:0.75,AJs,AJo:0.5,ATs:0.75,A6s:0.25,A5s:0.75,A4s:0.75,A3s:0.5,A2s:0.5,KQs,KQo:0.5,KJs,KTs:0.75,K5s:0.25,K4s:0.25,QJs:0.75,QTs:0.75,Q9s:0.5,JTs:0.75,J9s:0.75,J8s:0.75,T9s:0.75,T8s:0.75,T7s:0.75,98s:0.75,97s:0.75,96s:0.5,87s:0.75,86s:0.5,85s:0.5,76s:0.75,75s:0.5,65s:0.75,64s:0.5,54s:0.75,53s:0.5,43s:0.5 +set_range_ip 
QQ:0.5,JJ:0.75,TT,99,88,77,66,55,44,33,22,AKo:0.25,AQs,AQo:0.75,AJs,AJo:0.75,ATs,ATo:0.75,A9s,A8s,A7s,A6s,A5s,A4s,A3s,A2s,KQ,KJ,KTs,KTo:0.5,K9s,K8s,K7s,K6s,K5s,K4s:0.5,K3s:0.5,K2s:0.5,QJ,QTs,Q9s,Q8s,Q7s,JTs,JTo:0.5,J9s,J8s,T9s,T8s,T7s,98s,97s,96s,87s,86s,76s,75s,65s,64s,54s,53s,43s +set_bet_sizes oop,flop,bet,100 +set_bet_sizes oop,flop,raise,50 +set_bet_sizes oop,flop,allin +set_bet_sizes ip,flop,bet,100 +set_bet_sizes ip,flop,raise,50 +set_bet_sizes ip,flop,allin +set_bet_sizes oop,turn,bet,100 +set_bet_sizes oop,turn,donk,100 +set_bet_sizes oop,turn,raise,50 +set_bet_sizes oop,turn,allin +set_bet_sizes ip,turn,bet,100 +set_bet_sizes ip,turn,raise,50 +set_bet_sizes oop,river,bet,100 +set_bet_sizes oop,river,donk,100 +set_bet_sizes oop,river,raise,50 +set_bet_sizes oop,river,allin +set_bet_sizes ip,river,bet,100 +set_bet_sizes ip,river,raise,50 +set_bet_sizes ip,river,allin +set_allin_threshold 1.0 +set_raise_limit 2 +build_tree +estimate_tree_memory +set_thread_num 6 +#set_thread_num 81920 +set_slice_cfr 0 +set_accuracy 0.3 +set_max_iteration 1 +set_print_interval 10 +set_use_isomorphism 1 +start_solve +set_dump_rounds 1 +dump_result output_result2.json +#dump_setting output_setting.txt diff --git a/boardselector.cpp b/boardselector.cpp index 6250eb9..82f1db3 100644 --- a/boardselector.cpp +++ b/boardselector.cpp @@ -1,7 +1,7 @@ #include "boardselector.h" #include "ui_boardselector.h" -boardselector::boardselector(QTextEdit* boardEdit,QSolverJob::Mode mode,QWidget *parent) : +boardselector::boardselector(QTextEdit* boardEdit,PokerMode mode,QWidget *parent) : QDialog(parent), ui(new Ui::boardselector) { @@ -11,9 +11,9 @@ boardselector::boardselector(QTextEdit* boardEdit,QSolverJob::Mode mode,QWidget this->mode = mode; QString ranks; - if(mode == QSolverJob::Mode::HOLDEM){ + if(mode == PokerMode::HOLDEM){ ranks = "A,K,Q,J,T,9,8,7,6,5,4,3,2"; - }else if(mode == QSolverJob::Mode::SHORTDECK){ + }else if(mode == PokerMode::SHORTDECK){ ranks = "A,K,Q,J,T,9,8,7,6"; 
}else{ throw runtime_error("mode not found in range selector"); diff --git a/boardselector.h b/boardselector.h index 99d9a5c..bb1ee8d 100644 --- a/boardselector.h +++ b/boardselector.h @@ -18,7 +18,7 @@ class boardselector : public QDialog Q_OBJECT public: - explicit boardselector(QTextEdit* boardEdit,QSolverJob::Mode mode = QSolverJob::Mode::HOLDEM,QWidget *parent = 0); + explicit boardselector(QTextEdit* boardEdit,PokerMode mode = PokerMode::HOLDEM,QWidget *parent = 0); ~boardselector(); private slots: @@ -37,7 +37,7 @@ private slots: private: Ui::boardselector *ui; QTextEdit* boardEdit = NULL; - QSolverJob::Mode mode; + PokerMode mode; QStringList rank_list; BoardSelectorTableModel * boardSelectorTableModel = NULL; BoardSelectorTableDelegate * boardSelectorTableDelegate = NULL; diff --git a/imgs/texassolver_logo.rc b/imgs/texassolver_logo.rc new file mode 100644 index 0000000..e0626ca --- /dev/null +++ b/imgs/texassolver_logo.rc @@ -0,0 +1 @@ +IDI_ICON1 ICON "texassolver_logo.ico" \ No newline at end of file diff --git a/include/Card.h b/include/Card.h index 7b52585..a64cd8f 100644 --- a/include/Card.h +++ b/include/Card.h @@ -8,7 +8,7 @@ #include #include #include "include/tools/tinyformat.h" -#include +// #include using namespace std; class Card { @@ -20,17 +20,17 @@ class Card { Card(); explicit Card(string card,int card_number_in_deck); Card(string card); - string getCard(); + const string& getCard(); int getCardInt(); bool empty(); int getNumberInDeckInt(); static int card2int(Card card); - static int strCard2int(string card); + static int strCard2int(const string &card); static string intCard2Str(int card); static uint64_t boardCards2long(vector cards); static uint64_t boardCard2long(Card& card); static uint64_t boardCards2long(vector& cards); - static QString boardCards2html(vector& cards); + // static QString boardCards2html(vector& cards); static inline bool boardsHasIntercept(uint64_t board1,uint64_t board2){ return ((board1 & board2) != 0); }; @@ 
-43,9 +43,9 @@ class Card { static int rankToInt(char rank); static int suitToInt(char suit); static vector getSuits(); - string toString(); - string toFormattedString(); - QString toFormattedHtml(); + // string toString(); + // string toFormattedString(); + // QString toFormattedHtml(); }; #endif //TEXASSOLVER_CARD_H diff --git a/include/card_format.h b/include/card_format.h new file mode 100644 index 0000000..83aed98 --- /dev/null +++ b/include/card_format.h @@ -0,0 +1,11 @@ +#if !defined(_CARD_FORMAT_H_) +#define _CARD_FORMAT_H_ + +#include +#include "include/Card.h" + +string toFormattedString(Card &card); +QString toFormattedHtml(Card &card); +QString boardCards2html(vector& cards); + +#endif // _CARD_FORMAT_H_ diff --git a/include/library.h b/include/library.h index cb958f8..59f3d7b 100644 --- a/include/library.h +++ b/include/library.h @@ -78,7 +78,7 @@ Combinations::comb(unsigned long long n, unsigned long long k) { return r; } -vector string_split(string strin,char split); +vector string_split(string &strin, char split); uint64_t timeSinceEpochMillisec(); int random(int min, int max); float normalization_tanh(float stack,float ev,float ratio=7); diff --git a/include/ranges/RiverRangeManager.h b/include/ranges/RiverRangeManager.h index cf57d4f..7d29eee 100644 --- a/include/ranges/RiverRangeManager.h +++ b/include/ranges/RiverRangeManager.h @@ -18,6 +18,10 @@ class RiverRangeManager { RiverRangeManager(shared_ptr handEvaluator); const vector& getRiverCombos(int player, const vector& riverCombos, const vector& board); const vector& getRiverCombos(int player, const vector& riverCombos, uint64_t board_long); + void clear() { + p1RiverRanges.clear(); + p2RiverRanges.clear(); + } private: unordered_map> p1RiverRanges; unordered_map> p2RiverRanges; diff --git a/include/runtime/PokerSolver.h b/include/runtime/PokerSolver.h index 705919b..f67487f 100644 --- a/include/runtime/PokerSolver.h +++ b/include/runtime/PokerSolver.h @@ -12,14 +12,22 @@ #include 
"include/solver/CfrSolver.h" #include "include/solver/PCfrSolver.h" #include "include/library.h" -#include -#include +#include "include/solver/slice_cfr.h" +// #include +// #include using namespace std; +enum PokerMode { + HOLDEM, + SHORTDECK, + UNKNOWN +}; + class PokerSolver { public: - PokerSolver(); - PokerSolver(string ranks,string suits,string compairer_file,int compairer_file_lines,string compairer_file_bin); + PokerSolver() {} + PokerSolver(PokerMode mode, string &resource_dir); + PokerSolver(string &ranks, string &suits, string &compairer_file, int compairer_file_lines, string &compairer_file_bin); void load_game_tree(string game_tree_file); void build_game_tree( float oop_commit, @@ -33,28 +41,31 @@ class PokerSolver { float allin_threshold ); void train( - string p1_range, - string p2_range, - string boards, - string log_file, + string &p1_range, + string &p2_range, + string &boards, + // string &log_file, int iteration_number, int print_interval, - string algorithm, + string &algorithm, int warmup, float accuracy, bool use_isomorphism, int use_halffloats, - int threads + int threads, + int slice_cfr = 0 ); void stop(); - long long estimate_tree_memory(QString range1,QString range2,QString board); + long long estimate_tree_memory(string& p1_range, string& p2_range, string& board); vector player1Range; vector player2Range; - void dump_strategy(QString dump_file,int dump_rounds); + void dump_strategy(string &dump_file, int dump_rounds); shared_ptr get_game_tree(){return this->game_tree;}; Deck* get_deck(){return &this->deck;} shared_ptr get_solver(){return this->solver;} + Logger *logger = nullptr; private: + void init(string &ranks, string &suits, string &compairer_file, int compairer_file_lines, string &compairer_file_bin); shared_ptr compairer; Deck deck; shared_ptr game_tree; diff --git a/include/runtime/qsolverjob.h b/include/runtime/qsolverjob.h index 507bb05..c701ee6 100644 --- a/include/runtime/qsolverjob.h +++ b/include/runtime/qsolverjob.h @@ 
-16,11 +16,11 @@ class QSolverJob : public QThread private: QSTextEdit * textEdit; public: - enum Mode{ - HOLDEM, - SHORTDECK - }; - Mode mode = Mode::HOLDEM; + // enum Mode{ + // HOLDEM, + // SHORTDECK + // }; + PokerMode mode = PokerMode::HOLDEM; enum MissionType{ LOADING, @@ -31,6 +31,7 @@ class QSolverJob : public QThread MissionType current_mission = MissionType::LOADING; string resource_dir; PokerSolver ps_holdem,ps_shortdeck; + /* float oop_commit=5; float ip_commit=5; int current_round=1; @@ -50,7 +51,9 @@ class QSolverJob : public QThread int print_interval=10; int dump_rounds = 2; shared_ptr gtbs; - + */ + CommandLineTool *clt = nullptr; + Logger *logger = nullptr; PokerSolver* get_solver(); void run(); void loading(); @@ -58,8 +61,8 @@ class QSolverJob : public QThread void stop(); void saving(); void build_tree(); - long long estimate_tree_memory(QString range1,QString range2,QString board); + long long estimate_tree_memory(string &range1, string &range2, string &board); void setContext(QSTextEdit * textEdit); - QString savefile; + // QString savefile; }; #endif // QSOLVERJOB_H diff --git a/include/solver/BestResponse.h b/include/solver/BestResponse.h index a161a8a..efc86e3 100644 --- a/include/solver/BestResponse.h +++ b/include/solver/BestResponse.h @@ -16,6 +16,7 @@ #include #include #include +#include "include/tools/logger.h" using namespace std; @@ -47,7 +48,7 @@ class BestResponse { ); float printExploitability(shared_ptr root, int iterationCount, float initial_pot, uint64_t initialBoard); float getBestReponseEv(shared_ptr node, int player,vector> reach_probs, uint64_t initialBoard,int deal); - + Logger *logger = nullptr; private: vector bestResponse(shared_ptr node, int player, const vector>& reach_probs, uint64_t board,int deal); vector chanceBestReponse(shared_ptr node, int player, const vector>& reach_probs, uint64_t current_board,int deal); diff --git a/include/solver/PCfrSolver.h b/include/solver/PCfrSolver.h index cce4f68..7017afb 100644 
--- a/include/solver/PCfrSolver.h +++ b/include/solver/PCfrSolver.h @@ -83,7 +83,7 @@ class PCfrSolver:public Solver { int iteration_number, bool debug, int print_interval, - string logfile, + /*string logfile*/Logger *logger, string trainer, Solver::MonteCarolAlg monteCarolAlg, int warmup, diff --git a/include/solver/Solver.h b/include/solver/Solver.h index d7f7271..2cfd9c5 100644 --- a/include/solver/Solver.h +++ b/include/solver/Solver.h @@ -5,7 +5,7 @@ #ifndef TEXASSOLVER_SOLVER_H #define TEXASSOLVER_SOLVER_H - +#include "include/tools/logger.h" #include class Solver { @@ -15,7 +15,7 @@ class Solver { PUBLIC }; Solver(); - Solver(shared_ptr tree); + Solver(shared_ptr tree, Logger *logger); shared_ptr getTree(); virtual void train() = 0; virtual void stop() = 0; @@ -23,6 +23,7 @@ class Solver { virtual vector>> get_strategy(shared_ptr node,vector cards) = 0; virtual vector>> get_evs(shared_ptr node,vector cards) = 0; shared_ptr tree; + Logger *logger = nullptr; }; diff --git a/include/solver/cuda_cfr.h b/include/solver/cuda_cfr.h new file mode 100644 index 0000000..6b3b010 --- /dev/null +++ b/include/solver/cuda_cfr.h @@ -0,0 +1,85 @@ +#ifndef _CUDA_CFR_H_ +#define _CUDA_CFR_H_ + +#include +#include +#include +#include +#include "include/nodes/GameTreeNode.h" +#include "include/solver/PCfrSolver.h" +#include +#include +#include "cuda_runtime.h" +#include "include/solver/slice_cfr.h" + +#define LANE_SIZE 32 + +struct CudaLeafNode { + float val = 0;// fold:player0的收益*随机概率,sd:胜者收益*随机概率 + int offset_prob_sum = 0; + int offset_p0 = 0; + int offset_p1 = 0; + float *data_p0 = nullptr; + float *data_p1 = nullptr; + int *info = nullptr; +}; +struct SDNode { + float val = 0;// 胜者收益*随机概率 + int offset_prob_sum = 0; + int offset_p0 = 0; + int offset_p1 = 0; + float *data_p0 = nullptr; + float *data_p1 = nullptr; + int *strength_data = nullptr; +}; + +class CudaCFR : public SliceCFR { +public: + CudaCFR( + shared_ptr tree, + vector &range1, + vector &range2, + vector 
&initial_board, + shared_ptr compairer, + Deck &deck, + int train_step, + int print_interval, + float accuracy, + int n_thread, + Logger *logger + ):SliceCFR(tree, range1, range2, initial_board, compairer, deck, train_step, print_interval, accuracy, n_thread, logger) {} + virtual ~CudaCFR(); + virtual size_t estimate_tree_size(); +protected: + int *dev_hand_card = nullptr; + int *dev_hand_card_ptr[N_PLAYER] {nullptr,nullptr}; + size_t *dev_hand_hash = nullptr; + size_t *dev_hand_hash_ptr[N_PLAYER] {nullptr,nullptr}; + int *dev_same_hand_idx = nullptr; + Node *dev_nodes = nullptr;// cuda内存地址 + CudaLeafNode *dev_leaf_node = nullptr;// cuda内存地址 + vector dev_data; + vector dev_strength; + float *dev_root_cfv = nullptr, *dev_prob_sum = nullptr; + virtual size_t init_memory(); + size_t init_player_node(); + size_t init_leaf_node(); + void set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset); + size_t init_strength_table(); + virtual void step(int iter, int player, int task); + virtual void leaf_cfv(int player); + int block_size(int size) {// ceil + return (size + LANE_SIZE - 1) / LANE_SIZE; + } + void clear_prob_sum(int len); + virtual void _reach_prob(int player, bool avg_strategy); + virtual void _rm(int player, bool avg_strategy); + virtual void clear_data(int player); + virtual void clear_root_cfv(); + virtual void post_process(); + virtual vector> get_avg_strategy(int idx); + virtual vector> get_ev(int idx); + virtual void cfv_to_ev(); +}; + +#endif // _CUDA_CFR_H_ diff --git a/include/solver/cuda_func.h b/include/solver/cuda_func.h new file mode 100644 index 0000000..b4c807d --- /dev/null +++ b/include/solver/cuda_func.h @@ -0,0 +1,23 @@ +#ifndef _CUDA_FUNC_H_ +#define _CUDA_FUNC_H_ + +#include "cuda_runtime.h" + +extern __host__ __device__ void print_data(int *arr, int n); +extern __host__ __device__ void print_data(size_t *arr, int n); +extern __host__ __device__ void print_data(float *arr, int n); +extern __global__ void print_data_kernel(int 
*arr, int n); +extern __global__ void print_data_kernel(size_t *arr, int n); +extern __global__ void print_data_kernel(float *arr, int n); +extern __global__ void clear_data_kernel(Node *node, int size, int n_hand); +extern __global__ void rm_avg_kernel(Node *node, int size, int n_hand); +extern __global__ void rm_kernel(Node *node, int size, int n_hand); +extern __global__ void reach_prob_avg_kernel(Node *node, int size, int n_hand); +extern __global__ void reach_prob_kernel(Node *node, int size, int n_hand); +extern __global__ void fold_cfv_kernel(int player, int size, CudaLeafNode *node, float *opp_prob_sum, int my_hand, int opp_hand, int *hand_card, size_t *hand_hash, int *same_hand_idx); +extern __global__ void sd_cfv_kernel(int player, int size, CudaLeafNode *node, float *opp_prob_sum, int my_hand, int opp_hand, int *my_card, int *opp_card, int n_card); +extern __global__ void best_cfv_kernel(Node *node, int size, int n_hand); +extern __global__ void cfv_kernel(Node *node, int size, int n_hand); +extern __global__ void discount_data_kernel(Node *node, int size, int n_hand, float pos_coef, float neg_coef, float coef); + +#endif // _CUDA_FUNC_H_ \ No newline at end of file diff --git a/include/solver/slice_cfr.h b/include/solver/slice_cfr.h new file mode 100644 index 0000000..231f09f --- /dev/null +++ b/include/solver/slice_cfr.h @@ -0,0 +1,205 @@ +#ifndef _SLICE_CFR_H_ +#define _SLICE_CFR_H_ + +#include +#include +#include +#include +#include "include/nodes/GameTreeNode.h" +#include "include/solver/PCfrSolver.h" +#include +#include +#include + +using std::vector; +using std::unordered_set; +using std::unordered_map; +using std::dynamic_pointer_cast; +using std::mutex; + +#define N_CARD 52 +#define N_PLAYER 2 +#define P0 0 +#define P1 1 +#define CHANCE_PLAYER N_PLAYER + +#define N_ROUND 4 +#define PREFLOP_ROUND 0 +#define FLOP_ROUND 1 +#define TURN_ROUND 2 +#define RIVER_ROUND 3 + +#define FOLD_TYPE 0 +#define SHOWDOWN_TYPE 1 +#define N_LEAF_TYPE 2 + +#define 
N_TYPE 5 + +#define two_card_hash(card1, card2) ((1LL<<(card1)) | (1LL<<(card2))) +#define tril_idx(r, c) (((r)*((r)-1)>>1)+(c)) // r>c>=0 + +#define get_size(n_act, n_hand) (((n_act) * 4 + 1) * (n_hand)) +#define cfv_offset(n_hand, act_idx) ((n_hand) * (act_idx)) +#define reach_prob_offset(n_act, n_hand, act_idx) (((n_act) * 3 + (act_idx)) * (n_hand)) +#define reach_prob_to_cfv(n_act, n_hand) ((n_act) * (n_hand) * 3) + +// 数组poss_card的索引[0,51]-->[1,52],8位二进制编码,最多选两个,占用高16位,低16位预留其他用途 +#define code_idx0(i) (((i)+1)<<24) +#define decode_idx0(x) (((x)>>24) - 1) +#define code_idx1(i) (((i)+1)<<16) +#define decode_idx1(x) ((((x)>>16)&0xff) - 1) + +#define EXP_TASK 0 +#define CFV_TASK 1 +#define CFR_TASK 2 + +struct Node { + int n_act = 0;// 动作数 + int parent_offset = -1;// 本节点对应的父节点数据reach_prob的偏移量 + float *parent_cfv = nullptr; + // mutex *mtx = nullptr; + float *data = nullptr;// cfv,regret_sum,strategy_sum,reach_prob,sum + float *opp_prob = nullptr; + size_t board = 0LL; +}; +struct LeafNode { + float *reach_prob[N_PLAYER] = {nullptr,nullptr}; + size_t info = 0; +}; +struct PreLeafNode { + PreLeafNode(float *cfv):cfv(cfv) {} + float *cfv = nullptr; + vector leaf_node_idx; +}; +struct DFSNode { + DFSNode(int player, int n_act, int parent_act, int info, int parent_dfs_idx, int parent_p0_act, int parent_p0_idx, int parent_p1_act, int parent_p1_idx) + :player(player), n_act(n_act), parent_act(parent_act), info(info), parent_dfs_idx(parent_dfs_idx) + , parent_p0_act(parent_p0_act), parent_p0_idx(parent_p0_idx), parent_p1_act(parent_p1_act), parent_p1_idx(parent_p1_idx) {} + int player = -1;// 活动玩家(叶子节点时为父节点玩家) + int n_act = 0;// 动作数 + int parent_act = -1;// 本节点对应的父节点动作索引 + int info = 0; + int parent_dfs_idx = -1; + int parent_p0_act = -1; + int parent_p0_idx = -1; + int parent_p1_act = -1; + int parent_p1_idx = -1; +}; + +struct StrengthData { + StrengthData(int size, const RiverCombs *p):size(size), data(p) {} + int size = 0; + const RiverCombs *data = nullptr; +}; + 
+class SliceCFR : public Solver { +public: + SliceCFR( + shared_ptr tree, + vector &range1, + vector &range2, + vector &initial_board, + shared_ptr compairer, + Deck &deck, + int train_step, + int print_interval, + float accuracy, + int n_thread, + Logger *logger + ); + virtual ~SliceCFR(); + virtual size_t estimate_tree_size(); + void train(); + vector exploitability(); + void stop(); + json dumps(bool with_status, int depth); + vector>> get_strategy(shared_ptr node, vector cards); + vector>> get_evs(shared_ptr node, vector cards); +protected: + atomic_bool stop_flag {false}; + bool init_succ = false; + int n_thread = 0; + int steps = 0, interval = 0, n_card = N_CARD, min_card = 0; + int init_round = 0; + int dfs_idx = 0;// 先序遍历 + unordered_map> node_idx; + int combination_num[N_ROUND-1] {1,N_CARD,N_CARD*N_CARD}; + size_t init_board = 0; + int hand_size[N_PLAYER]; + float norm = 1;// 根节点概率归一化系数 + float tol = 0.01;// exploitability容忍度 + float alpha = 1.5, beta = 0, gamma = 2; + float pos_coef = 0, neg_coef = 0, coef = 0; + RiverRangeManager rrm; + vector hand_card;// p0_card1,p0_card2,p1_card1,p1_card2,相对于min_card的偏移量 + int *hand_card_ptr[N_PLAYER] {nullptr,nullptr}; + vector hand_hash; + size_t *hand_hash_ptr[N_PLAYER] {nullptr,nullptr}; + vector poss_card; + int chance_branch[N_ROUND]; + int chance_den[N_ROUND]; + vector same_hand_idx; + int *same_hand_ptr[N_PLAYER] {nullptr,nullptr}; + vector> ranges; + vector dfs_node; + vector dfs_idx_map;// dfs遍历的每个节点在内存中的索引 + int node_cnt[N_TYPE]; + int n_leaf_node = 0; + int n_player_node = 0; + vector> leaf_node_dfs; + vector chance_node; + vector> ev; + float *ev_ptr = nullptr; + vector>> slice; + vector> slice_offset; + vector root_cfv, root_prob;// P0_cfv,P1_cfv,P0_prob,P1_prob + float *root_prob_ptr[N_PLAYER] {nullptr,nullptr}; + float *root_cfv_ptr[N_PLAYER] {nullptr,nullptr}; + // shared_ptr tree = nullptr; + Deck& deck; + void init(); + void init_hand_card(vector &range1, vector &range2); + void 
init_hand_card(vector &range, vector &cards, vector &prob, size_t board, vector &out); + void init_same_hand_idx(); + void init_min_card(); + virtual size_t init_memory(); + size_t init_player_node(); + size_t init_leaf_node(); + void set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset); + void normalization(); + size_t init_strength_table(); + void dfs(shared_ptr node, int parent_act=-1, int parent_dfs_idx=-1, int parent_p0_act=-1, int parent_p0_idx=-1, int parent_p1_act=-1, int parent_p1_idx=-1, int cnt0=0, int cnt1=0, int info=0); + void init_poss_card(Deck& deck, size_t board); + virtual void step(int iter, int player, int task); + virtual void leaf_cfv(int player); + void fold_cfv(int player, float *cfv, float *opp_reach, int my_hand, float val, size_t board); + void sd_cfv(int player, float *cfv, float *opp_reach, int my_hand, int opp_hand, float val, int idx); + void append_node_idx(int p_idx, int act_idx, int player, int cpu_node_idx); + vector> pre_leaf_node_map;// [dfs_idx,act_idx] + vector> pre_leaf_node;// [player,idx] + vector> root_child_idx; + vector leaf_node; + vector player_node; + Node *player_node_ptr = nullptr; + int sd_offset = 0; + // vector mtx; + // vector> mtx_map; + // int mtx_idx = N_PLAYER; + vector> strength; + size_t _estimate_tree_size(shared_ptr node); + virtual void _reach_prob(int player, bool avg_strategy); + virtual void _rm(int player, bool avg_strategy); + virtual void clear_data(int player); + virtual void clear_root_cfv(); + virtual void post_process() {} + json reConvertJson(const shared_ptr& node, int depth, int max_depth, int &idx, int info); + virtual vector> get_avg_strategy(int idx);// [n_hand,n_act] + virtual vector> get_ev(int idx);// [n_hand,n_act] + bool print_exploitability(int iter, Timer &timer); + virtual void cfv_to_ev(); + void cfv_to_ev(Node *node, int player); + void get_prob_sum(vector &prob_sum, float &sum, int player, float *reach_prob, size_t board); + void output_data(ActionNode 
*node, vector &cards, vector>> &out, bool ev); +}; + +#endif // _SLICE_CFR_H_ diff --git a/include/tools/CommandLineTool.h b/include/tools/CommandLineTool.h index 3b078fa..f68cde4 100644 --- a/include/tools/CommandLineTool.h +++ b/include/tools/CommandLineTool.h @@ -13,23 +13,62 @@ using namespace std; class CommandLineTool{ public: - CommandLineTool(string mode,string resource_dir); - void startWorking(); - void execFromFile(string input_file); - void processCommand(string input); -private: - enum Mode{ - HOLDEM, - SHORTDECK - }; - Mode mode; - string resource_dir; - PokerSolver ps; + CommandLineTool(); + void startWorking(PokerSolver *ps); + void execFromFile(const char *input_file, PokerSolver *ps); + void processCommand(string &input, PokerSolver *ps); + void dump_setting(const char *file); + void set_pot(float val) { + ip_commit = oop_commit = val / 2; + } + float get_pot() { + return ip_commit + oop_commit; + } + void set_effective_stack(float val) { + stack = val + ip_commit; + } + float get_effective_stack() { + return stack - ip_commit; + } + bool set_board(string &str); + bool set_bet_sizes(string &str, char delimiter = ',', vector *sizes = nullptr); + void build_tree(PokerSolver *ps) { + if(!ps) return; + ps->build_game_tree(oop_commit,ip_commit,current_round,raise_limit,small_blind,big_blind,stack,gtbs,allin_threshold); + } + void start_solve(PokerSolver *ps) { + if(!ps) return; + // cout << "<<>>" << endl; + logger->log("<<>>"); + ps->train( + range_ip, + range_oop, + board, + // "tmp_log.txt", + max_iteration, + print_interval, + algorithm, + -1, + accuracy, + use_isomorphism, + use_halffloats, + thread_num, + slice_cfr + ); + } +// private: + // enum Mode{ + // HOLDEM, + // SHORTDECK + // }; + // Mode mode; + // string resource_dir; + // PokerSolver ps; float oop_commit=5; float ip_commit=5; int current_round=1; int raise_limit=4; - int thread_number=1; + int thread_num=1; float small_blind=0.5; float big_blind=1; float stack=20 + 5; @@ -37,12 +76,27 
@@ class CommandLineTool{ string range_ip; string range_oop; string board; - float accuracy; + string res_file; + string algorithm = "discounted_cfr"; + float accuracy = 0.1; int max_iteration=100; - int use_isomorphism=0; + bool use_isomorphism=0; + int use_halffloats=0; int print_interval=10; + int slice_cfr = 0; int dump_rounds = 1; - shared_ptr gtbs; + GameTreeBuildingSettings gtbs; + Logger *logger = nullptr; }; +void split(const string& s, char delimiter, vector& v); +void join(const vector &vec, char delimiter, string &out); + +template +string tostring(T val); +template +string tostring_oss(T val); + +int cmd_api(string &input_file, string &resource_dir, string &mode, string &log_file); + #endif //BINDSOLVER_COMMANDLINETOOL_H diff --git a/include/tools/GameTreeBuildingSettings.h b/include/tools/GameTreeBuildingSettings.h index c72da89..394d193 100644 --- a/include/tools/GameTreeBuildingSettings.h +++ b/include/tools/GameTreeBuildingSettings.h @@ -8,6 +8,7 @@ class GameTreeBuildingSettings { public: + GameTreeBuildingSettings() {} GameTreeBuildingSettings( StreetSetting flop_ip, StreetSetting turn_ip, @@ -21,7 +22,7 @@ class GameTreeBuildingSettings { StreetSetting flop_oop; StreetSetting turn_oop; StreetSetting river_oop; - StreetSetting& get_setting(string player,string round); + StreetSetting& get_setting(string &player, string &round); }; #endif //BINDSOLVER_GAMETREEBUILDINGSETTINGS_H diff --git a/include/tools/StreetSetting.h b/include/tools/StreetSetting.h index 6aa1e2b..9c4721c 100644 --- a/include/tools/StreetSetting.h +++ b/include/tools/StreetSetting.h @@ -12,8 +12,9 @@ class StreetSetting { vector bet_sizes; vector raise_sizes; vector donk_sizes; - bool allin; + bool allin = true; + StreetSetting() {} StreetSetting(vector bet_sizes, vector raise_sizes, vector donk_sizes, bool allin); }; diff --git a/include/tools/logger.h b/include/tools/logger.h new file mode 100644 index 0000000..b121be6 --- /dev/null +++ b/include/tools/logger.h @@ -0,0 +1,42 
@@ +#if !defined(_LOGGER_H_) +#define _LOGGER_H_ + +#include +#include +#include +#include + +using std::string; + +void get_localtime(char *buf, size_t n, const char *format); +string get_localtime(); + +class Logger { +public: + Logger(bool cmd, const char *path, const char *mode = "w+", bool timestamp = false, bool new_line = true, int period = 10) + :cmd(cmd), timestamp(timestamp), new_line(new_line), period(period) { + if(path) { + file = fopen(path, mode); + if(!file) printf("failed to create file %s\n", path); + } + } + virtual ~Logger() { + if(file) { + fflush(file); + fclose(file); + } + } + virtual void log(const char *format, ...); + void flush() { + if(file) fflush(file); + } +protected: + void log_time(); + int step = 0, period = 10; + FILE *file = nullptr; + bool timestamp = false; + bool cmd = true; + bool new_line = true; +}; + +#endif // _LOGGER_H_ diff --git a/include/tools/utils.h b/include/tools/utils.h index c4f9ef9..6b5b70b 100644 --- a/include/tools/utils.h +++ b/include/tools/utils.h @@ -64,4 +64,29 @@ void exchange_color(vector& value,vector range,int rank1,int ra //throw runtime_error("exiting...here..."); } +#include +class Timer { +public: + Timer() { + reset(); + } + void reset() { + start = std::chrono::steady_clock::now(); + } + int64_t ms(bool reset=false) { + std::chrono::steady_clock::time_point curr = std::chrono::steady_clock::now(); + int64_t ans = std::chrono::duration_cast(curr-start).count(); + if(reset) start = curr; + return ans; + } + int64_t us(bool reset=false) { + std::chrono::steady_clock::time_point curr = std::chrono::steady_clock::now(); + int64_t ans = std::chrono::duration_cast(curr-start).count(); + if(reset) start = curr; + return ans; + } +private: + std::chrono::steady_clock::time_point start {}; +}; + #endif //BINDSOLVER_UTILS_H diff --git a/include/ui/tablestrategymodel.h b/include/ui/tablestrategymodel.h index ede362a..44f098f 100644 --- a/include/ui/tablestrategymodel.h +++ 
b/include/ui/tablestrategymodel.h @@ -12,6 +12,7 @@ #include "include/ui/treeitem.h" #include "include/nodes/GameActions.h" #include +#include "include/card_format.h" class TableStrategyModel : public QAbstractItemModel { diff --git a/mainwindow.cpp b/mainwindow.cpp index ec81aa4..3ed6a4b 100644 --- a/mainwindow.cpp +++ b/mainwindow.cpp @@ -18,17 +18,21 @@ MainWindow::MainWindow(QWidget *parent) : connect(this->ui->actionimport, &QAction::triggered, this, &MainWindow::on_actionimport_triggered); connect(this->ui->actionexport, &QAction::triggered, this, &MainWindow::on_actionexport_triggered); connect(this->ui->actionclear_all, &QAction::triggered, this, &MainWindow::on_actionclear_all_triggered); + logger = new QLogger((get_localtime() + ".txt").c_str(), "w+", false, 1); + clt.logger = logger; qSolverJob = new QSolverJob; + qSolverJob->clt = &clt; qSolverJob->setContext(this->getLogArea()); + qSolverJob->logger = logger; qSolverJob->current_mission = QSolverJob::MissionType::LOADING; qSolverJob->start(); this->setWindowTitle(tr("TexasSolver")); - // parameters tree view QStringList filters; filters << "*.txt"; qFileSystemModel = new QFileSystemModel(this); - QDir filedir = QDir::current().filePath("parameters"); + QDir filedir = QDir::current()/*.filePath("parameters")*/; + logger->log("%s", filedir.absolutePath().toLocal8Bit().constData()); /* the path is data: log() treats its first argument as a printf format, so a '%' in the directory name must never reach it as the format */ qFileSystemModel->setRootPath(filedir.path()); #ifdef Q_OS_MAC filedir = QDir(""); @@ -63,6 +67,11 @@ MainWindow::MainWindow(QWidget *parent) : this->ui->oopRangeTableView->verticalHeader()->setMinimumSectionSize(1); this->ui->oopRangeTableView->horizontalHeader()->setMinimumSectionSize(1); this->ui->tabWidget->hide(); + + show_tree_params(); + show_solver_params(); + this->update(); + update_range_ui(); } QSTextEdit * MainWindow::get_logwindow(){ @@ -78,6 +87,7 @@ MainWindow::~MainWindow() delete oop_delegate; delete oop_model; delete ui; + if(logger) delete logger; } void MainWindow::on_actionjson_triggered(){ @@ -85,7 +95,11 @@ void
MainWindow::on_actionjson_triggered(){ "output_strategy.json", tr("Json file (*.json)")); if(fileName.isNull())return; - this->qSolverJob->savefile = fileName; + QSettings setting("TexasSolver", "Setting"); + setting.beginGroup("solver"); + clt.dump_rounds = setting.value("dump_round").toInt(); + clt.res_file = (const char*)fileName.toLocal8Bit(); + // this->qSolverJob->savefile = fileName; qSolverJob->current_mission = QSolverJob::MissionType::SAVING; qSolverJob->start(); } @@ -100,10 +114,7 @@ QString getParams(QString input,QString key){ void MainWindow::on_actionclear_all_triggered(){ this->clear_all_params(); - this->ui->IpRangeTableView->update(); - this->ui->oopRangeTableView->update(); - this->ui->IpRangeTableView->setFocus(); - this->ui->oopRangeTableView->setFocus(); + update_range_ui(); } void MainWindow::clear_all_params(){ @@ -147,6 +158,7 @@ void MainWindow::import_from_file(QString fileName){ qDebug().noquote() << tr("File selection invalid."); return; } + /* QFile file(fileName); if(!file.open(QIODevice::ReadOnly)){ qDebug().noquote() << tr("File open failed."); @@ -261,7 +273,10 @@ void MainWindow::import_from_file(QString fileName){ this->ui->useIsoCheck->setChecked(false); } } - } + }*/ + clt.execFromFile(fileName.toLocal8Bit(), nullptr); + show_tree_params(); + show_solver_params(); this->update(); } @@ -272,6 +287,9 @@ void MainWindow::on_actionimport_triggered(){ QDir::currentPath(), tr("Text files (*.txt)")); this->import_from_file(fileName); + update_range_ui(); +} +void MainWindow::update_range_ui() { this->ui->IpRangeTableView->update(); this->ui->oopRangeTableView->update(); this->ui->IpRangeTableView->setFocus(); @@ -283,7 +301,8 @@ void MainWindow::on_actionexport_triggered(){ "parameters/output_parameters.txt", tr("Text file (*.txt)")); if(fileName.isNull())return; - QString output_text = ""; + clt.dump_setting(fileName.toLocal8Bit()); + /*QString output_text = ""; QTextStream out(&output_text); out << "set_pot " << 
this->ui->potText->text().trimmed(); out << "\n"; @@ -398,6 +417,7 @@ void MainWindow::on_actionexport_triggered(){ msgBox.setText(message); setlocale(LC_CTYPE, "C"); msgBox.exec(); + */ } void MainWindow::on_actionSettings_triggered(){ @@ -412,10 +432,15 @@ void MainWindow::on_ip_range(QString range_text){ void MainWindow::on_buttomSolve_clicked() { + /* qSolverJob->max_iteration = ui->iterationText->text().toInt(); qSolverJob->accuracy = ui->exploitabilityText->text().toFloat(); qSolverJob->print_interval = ui->logIntervalText->text().toInt(); qSolverJob->thread_number = ui->threadsText->text().toInt(); + */ + get_solver_params(); + show_solver_params(); + this->update(); qSolverJob->current_mission = QSolverJob::MissionType::SOLVING; qSolverJob->start(); } @@ -451,6 +476,7 @@ vector sizes_convert(QString input){ void MainWindow::on_buildTreeButtom_clicked() { + /* qSolverJob->range_ip = this->ui->ipRangeText->toPlainText().toStdString(); qSolverJob->range_oop = this->ui->oopRangeText->toPlainText().toStdString(); qSolverJob->board = this->ui->boardText->toPlainText().toStdString(); @@ -470,7 +496,7 @@ void MainWindow::on_buildTreeButtom_clicked() qSolverJob->ip_commit = this->ui->potText->text().toFloat() / 2; qSolverJob->oop_commit = this->ui->potText->text().toFloat() / 2; qSolverJob->stack = this->ui->effectiveStackText->text().toFloat() + qSolverJob->ip_commit; - qSolverJob->mode = this->ui->mode_box->currentIndex() == 0 ? QSolverJob::Mode::HOLDEM:QSolverJob::Mode::SHORTDECK; + qSolverJob->mode = this->ui->mode_box->currentIndex() == 0 ? 
PokerMode::HOLDEM : PokerMode::SHORTDECK; qSolverJob->allin_threshold = this->ui->allinThresholdText->text().toFloat(); qSolverJob->use_isomorphism = this->ui->useIsoCheck->isChecked(); qSolverJob->use_halffloats = this->ui->useHalfFloats_box->currentIndex(); @@ -508,10 +534,103 @@ void MainWindow::on_buildTreeButtom_clicked() ); qSolverJob->gtbs = make_shared(gbs_flop_ip,gbs_turn_ip,gbs_river_ip,gbs_flop_oop,gbs_turn_oop,gbs_river_oop); + */ + get_tree_params(); + show_tree_params(); + this->update(); + this->ui->IpRangeTableView->update(); + this->ui->oopRangeTableView->update(); qSolverJob->current_mission = QSolverJob::MissionType::BUILDTREE; qSolverJob->start(); } +void MainWindow::get_tree_params() { + qSolverJob->mode = this->ui->mode_box->currentIndex() == 0 ? PokerMode::HOLDEM : PokerMode::SHORTDECK; + string val = this->ui->boardText->toPlainText().toStdString(); + if(!clt.set_board(val)) { + qDebug().noquote() << tfm::format("Error : board %s not recognized", val).c_str(); + return; + } + clt.range_ip = this->ui->ipRangeText->toPlainText().toStdString(); + clt.range_oop = this->ui->oopRangeText->toPlainText().toStdString(); + clt.raise_limit = this->ui->raiseLimitText->text().toInt(); + clt.set_pot(this->ui->potText->text().toFloat()); + clt.set_effective_stack(this->ui->effectiveStackText->text().toFloat()); + clt.allin_threshold = this->ui->allinThresholdText->text().toFloat(); + + set_bet_sizes(ui->flop_ip_bet, &clt.gtbs.flop_ip.bet_sizes); + set_bet_sizes(ui->flop_ip_raise, &clt.gtbs.flop_ip.raise_sizes); + clt.gtbs.flop_ip.allin = ui->flop_ip_allin->isChecked(); + set_bet_sizes(ui->turn_ip_bet, &clt.gtbs.turn_ip.bet_sizes); + set_bet_sizes(ui->turn_ip_raise, &clt.gtbs.turn_ip.raise_sizes); + clt.gtbs.turn_ip.allin = ui->turn_ip_allin->isChecked(); + set_bet_sizes(ui->river_ip_bet, &clt.gtbs.river_ip.bet_sizes); + set_bet_sizes(ui->river_ip_raise, &clt.gtbs.river_ip.raise_sizes); + clt.gtbs.river_ip.allin = ui->river_ip_allin->isChecked(); + + 
set_bet_sizes(ui->flop_oop_bet, &clt.gtbs.flop_oop.bet_sizes); + set_bet_sizes(ui->flop_oop_raise, &clt.gtbs.flop_oop.raise_sizes); + clt.gtbs.flop_oop.allin = ui->flop_oop_allin->isChecked(); + set_bet_sizes(ui->turn_oop_bet, &clt.gtbs.turn_oop.bet_sizes); + set_bet_sizes(ui->turn_oop_raise, &clt.gtbs.turn_oop.raise_sizes); + set_bet_sizes(ui->turn_oop_donk, &clt.gtbs.turn_oop.donk_sizes); + clt.gtbs.turn_oop.allin = ui->turn_oop_allin->isChecked(); + set_bet_sizes(ui->river_oop_bet, &clt.gtbs.river_oop.bet_sizes); + set_bet_sizes(ui->river_oop_raise, &clt.gtbs.river_oop.raise_sizes); + set_bet_sizes(ui->river_oop_donk, &clt.gtbs.river_oop.donk_sizes); + clt.gtbs.river_oop.allin = ui->river_oop_allin->isChecked(); +} + +void MainWindow::get_solver_params() { + clt.use_isomorphism = this->ui->useIsoCheck->isChecked(); + clt.use_halffloats = this->ui->useHalfFloats_box->currentIndex(); + clt.max_iteration = ui->iterationText->text().toInt(); + clt.accuracy = ui->exploitabilityText->text().toFloat(); + clt.print_interval = ui->logIntervalText->text().toInt(); + clt.thread_num = ui->threadsText->text().toInt(); +} + +void MainWindow::show_tree_params() { + ui->boardText->setText(clt.board.c_str()); + ui->ipRangeText->setText(clt.range_ip.c_str()); + ui->oopRangeText->setText(clt.range_oop.c_str()); + ui->raiseLimitText->setText(QString::number(clt.raise_limit)); + ui->potText->setText(QString::number(clt.get_pot())); + ui->effectiveStackText->setText(QString::number(clt.get_effective_stack())); + ui->allinThresholdText->setText(QString::number(clt.allin_threshold)); + + show_bet_sizes(ui->flop_ip_bet, clt.gtbs.flop_ip.bet_sizes); + show_bet_sizes(ui->flop_ip_raise, clt.gtbs.flop_ip.raise_sizes); + ui->flop_ip_allin->setChecked(clt.gtbs.flop_ip.allin); + show_bet_sizes(ui->turn_ip_bet, clt.gtbs.turn_ip.bet_sizes); + show_bet_sizes(ui->turn_ip_raise, clt.gtbs.turn_ip.raise_sizes); + ui->turn_ip_allin->setChecked(clt.gtbs.turn_ip.allin); + 
show_bet_sizes(ui->river_ip_bet, clt.gtbs.river_ip.bet_sizes); + show_bet_sizes(ui->river_ip_raise, clt.gtbs.river_ip.raise_sizes); + ui->river_ip_allin->setChecked(clt.gtbs.river_ip.allin); + + show_bet_sizes(ui->flop_oop_bet, clt.gtbs.flop_oop.bet_sizes); + show_bet_sizes(ui->flop_oop_raise, clt.gtbs.flop_oop.raise_sizes); + ui->flop_oop_allin->setChecked(clt.gtbs.flop_oop.allin); + show_bet_sizes(ui->turn_oop_bet, clt.gtbs.turn_oop.bet_sizes); + show_bet_sizes(ui->turn_oop_raise, clt.gtbs.turn_oop.raise_sizes); + show_bet_sizes(ui->turn_oop_donk, clt.gtbs.turn_oop.donk_sizes); + ui->turn_oop_allin->setChecked(clt.gtbs.turn_oop.allin); + show_bet_sizes(ui->river_oop_bet, clt.gtbs.river_oop.bet_sizes); + show_bet_sizes(ui->river_oop_raise, clt.gtbs.river_oop.raise_sizes); + show_bet_sizes(ui->river_oop_donk, clt.gtbs.river_oop.donk_sizes); + ui->river_oop_allin->setChecked(clt.gtbs.river_oop.allin); +} + +void MainWindow::show_solver_params() { + ui->useIsoCheck->setChecked(clt.use_isomorphism); + ui->useHalfFloats_box->setCurrentIndex(clt.use_halffloats); + ui->iterationText->setText(QString::number(clt.max_iteration)); + ui->exploitabilityText->setText(QString::number(clt.accuracy)); + ui->logIntervalText->setText(QString::number(clt.print_interval)); + ui->threadsText->setText(QString::number(clt.thread_num)); +} + void MainWindow::on_copyButtom_clicked() { ui->flop_oop_bet->setText(ui->flop_ip_bet->text()); @@ -543,7 +662,7 @@ void MainWindow::on_stopSolvingButton_clicked() void MainWindow::on_ipRangeSelectButtom_clicked() { - QSolverJob::Mode mode = this->ui->mode_box->currentIndex() == 0 ? QSolverJob::Mode::HOLDEM:QSolverJob::Mode::SHORTDECK; + PokerMode mode = this->ui->mode_box->currentIndex() == 0 ? 
PokerMode::HOLDEM:PokerMode::SHORTDECK; this->rangeSelector = new RangeSelector(this->ui->ipRangeText,this,mode); rangeSelector->setAttribute(Qt::WA_DeleteOnClose); rangeSelector->show(); @@ -551,14 +670,15 @@ void MainWindow::on_ipRangeSelectButtom_clicked() void MainWindow::on_oopRangeSelectButtom_clicked() { - QSolverJob::Mode mode = this->ui->mode_box->currentIndex() == 0 ? QSolverJob::Mode::HOLDEM:QSolverJob::Mode::SHORTDECK; + PokerMode mode = this->ui->mode_box->currentIndex() == 0 ? PokerMode::HOLDEM:PokerMode::SHORTDECK; this->rangeSelector = new RangeSelector(this->ui->oopRangeText,this,mode); rangeSelector->setAttribute(Qt::WA_DeleteOnClose); rangeSelector->show(); } float iso_corh(QString board){ - vector board_str_arr = string_split(board.toStdString(),','); + string board_str = board.toStdString(); + vector board_str_arr = string_split(board_str, ','); vector initialBoard; for(string one_board_str:board_str_arr){ initialBoard.push_back(Card(one_board_str)); @@ -584,7 +704,7 @@ float iso_corh(QString board){ void MainWindow::on_estimateMemoryButtom_clicked() { - long long memory_float = this->qSolverJob->estimate_tree_memory(this->ui->ipRangeText->toPlainText(),this->ui->oopRangeText->toPlainText(),this->ui->boardText->toPlainText()); + long long memory_float = this->qSolverJob->estimate_tree_memory(clt.range_ip, clt.range_oop, clt.board); // float32 should take 4bytes float corh = 1; if(this->ui->useIsoCheck->isChecked()){ @@ -620,7 +740,7 @@ void MainWindow::on_estimateMemoryButtom_clicked() void MainWindow::on_selectBoardButton_clicked() { - QSolverJob::Mode mode = this->ui->mode_box->currentIndex() == 0 ? QSolverJob::Mode::HOLDEM:QSolverJob::Mode::SHORTDECK; + PokerMode mode = this->ui->mode_box->currentIndex() == 0 ? 
PokerMode::HOLDEM:PokerMode::SHORTDECK; this->boardSelector = new boardselector(this->ui->boardText,mode,this); boardSelector->setAttribute(Qt::WA_DeleteOnClose); boardSelector->show(); @@ -636,10 +756,7 @@ void MainWindow::item_clicked(const QModelIndex& index){ QFileInfo fileinfo = QFileInfo(this->qFileSystemModel->filePath(index)); if(fileinfo.suffix() == "txt"){ this->import_from_file(this->qFileSystemModel->filePath(index)); - this->ui->IpRangeTableView->update(); - this->ui->oopRangeTableView->update(); - this->ui->IpRangeTableView->setFocus(); - this->ui->oopRangeTableView->setFocus(); + update_range_ui(); } } } diff --git a/mainwindow.h b/mainwindow.h index f6ee1fd..9eb56a0 100644 --- a/mainwindow.h +++ b/mainwindow.h @@ -13,6 +13,32 @@ #include "settingeditor.h" #include "include/ui/rangeselectortablemodel.h" #include "include/ui/rangeselectortabledelegate.h" +#include + +class QLogger : public Logger { +public: + QLogger(const char *path, const char *mode = "w+", bool timestamp = false, int period = 10):Logger(false, path, mode, timestamp, true, period) {} + virtual void log(const char *format, ...) 
{ + if(timestamp) log_time(); + va_list args; + va_start(args, format); + if(file) { + /* vfprintf consumes its va_list on every compiler, so give it a copy and keep args intact for the qDebug call below */ + va_list file_args; + va_copy(file_args, args); + vfprintf(file, format, file_args); + va_end(file_args); + if(new_line) fprintf(file, "\n"); + if((++step) == period) { /* flush only after the newline so a complete record is written out */ + step = 0; + fflush(file); + } + } + // qDebug().noquote() << QString::vasprintf(QObject::tr(format).toLocal8Bit(), args); + qDebug().noquote() << QString::vasprintf(QObject::tr(format).toStdString().c_str(), args); + va_end(args); + } +}; namespace Ui { class MainWindow; @@ -60,8 +86,24 @@ private slots: private: void clear_all_params(); + void get_tree_params(); + void get_solver_params(); + void show_tree_params(); + void show_solver_params(); + void set_bet_sizes(QLineEdit *edit, vector *sizes) { + string s = edit->text().toStdString(); + clt.set_bet_sizes(s, ' ', sizes); + } + void show_bet_sizes(QLineEdit *edit, vector &sizes) { + string s; + join(sizes, ' ', s); + edit->setText(s.c_str()); + } + void update_range_ui(); Ui::MainWindow *ui = NULL; QSolverJob* qSolverJob = NULL; + CommandLineTool clt; + Logger *logger = nullptr; QFileSystemModel * qFileSystemModel = NULL; StrategyExplorer* strategyExplorer = NULL; RangeSelector* rangeSelector = NULL; diff --git a/rangeselector.cpp b/rangeselector.cpp index 8e81a72..24db7c2 100644 --- a/rangeselector.cpp +++ b/rangeselector.cpp @@ -1,15 +1,15 @@ #include "rangeselector.h" #include "ui_rangeselector.h" -RangeSelector::RangeSelector(QTextEdit* rangeEdit,QWidget *parent,QSolverJob::Mode mode) : +RangeSelector::RangeSelector(QTextEdit* rangeEdit,QWidget *parent,PokerMode mode) : QDialog(parent), ui(new Ui::RangeSelector) { QString ranks; - if(mode == QSolverJob::Mode::HOLDEM){ + if(mode == PokerMode::HOLDEM){ ranks = "A,K,Q,J,T,9,8,7,6,5,4,3,2"; - }else if(mode == QSolverJob::Mode::SHORTDECK){ + }else if(mode == PokerMode::SHORTDECK){ ranks = "A,K,Q,J,T,9,8,7,6"; }else{ throw runtime_error("mode not found in range selector"); diff --git a/rangeselector.h b/rangeselector.h 
index 447e370..40438ad 100644 --- a/rangeselector.h +++ b/rangeselector.h @@ -24,14 +24,14 @@ class RangeSelector : public QDialog Q_OBJECT public: - explicit RangeSelector(QTextEdit* rangeEdit,QWidget *parent = 0,QSolverJob::Mode mode = QSolverJob::Mode::HOLDEM); + explicit RangeSelector(QTextEdit* rangeEdit,QWidget *parent = 0,PokerMode mode = PokerMode::HOLDEM); ~RangeSelector(); signals: void confirm_text(QString content); private: int max_val = 1000; float range_num = 1; - QSolverJob::Mode mode; + PokerMode mode; Ui::RangeSelector *ui; QStringList rank_list; RangeSelectorTableModel * rangeSelectorTableModel = NULL; diff --git a/src/Card.cpp b/src/Card.cpp index ca91ac1..f06c089 100644 --- a/src/Card.cpp +++ b/src/Card.cpp @@ -22,10 +22,11 @@ bool Card::empty(){ else return false; } -string Card::getCard() { +const string& Card::getCard() { return this->card; } +// rank * 4 + suit,[13,4] int Card::getCardInt() { return this->card_int; } @@ -39,7 +40,8 @@ int Card::card2int(Card card) { return strCard2int(card.getCard()); } -int Card::strCard2int(string card) { +// rank * 4 + suit,[13,4] +int Card::strCard2int(const string &card) { char rank = card.at(0); char suit = card.at(1); if(card.length() != 2){ @@ -74,14 +76,14 @@ uint64_t Card::boardCards2long(vector& cards){ return Card::boardInts2long(board_int); } -QString Card::boardCards2html(vector& cards){ +/*QString Card::boardCards2html(vector& cards){ QString ret_html = ""; for(auto one_card:cards){ if(one_card.empty())continue; ret_html += one_card.toFormattedHtml(); } return ret_html; -} +}*/ uint64_t Card::boardInt2long(int board){ // 这里hard code了一副扑克牌是52张 @@ -217,7 +219,7 @@ vector Card::getSuits(){ return {"c","d","h","s"}; } -string Card::toString() { +/*string Card::toString() { return this->card; } @@ -241,4 +243,4 @@ QString Card::toFormattedHtml() { else if(qString.contains("s")) qString = qString.replace("s", QString::fromLocal8Bit("♠<\/span>")); return qString; -} +}*/ diff --git a/src/api.cpp 
b/src/api.cpp index 5c9879c..1d28a08 100644 --- a/src/api.cpp +++ b/src/api.cpp @@ -12,20 +12,11 @@ #include EXPORT -int api(const char * input_file, const char * resource_dir = "./resources", const char * mode = "holdem") { +int api(const char * input_file, const char * resource_dir = "./resources", const char * mode = "holdem", const char *log_file = "") { string input_file_ = input_file; string resource_dir_ = resource_dir; string mode_ = mode; + string log_file_ = log_file; - if(mode_ != "holdem" && mode_ != "shortdeck") - throw runtime_error(tfm::format("mode %s error, not in ['holdem','shortdeck']", mode_)); - - if(input_file_.empty()) { - CommandLineTool clt = CommandLineTool(mode_, resource_dir_); - clt.startWorking(); - }else{ - cout << "EXEC FROM FILE" << endl; - CommandLineTool clt = CommandLineTool(mode_, resource_dir_); - clt.execFromFile(input_file_); - } + return cmd_api(input_file_, resource_dir_, mode_, log_file_); } \ No newline at end of file diff --git a/src/card_format.cpp b/src/card_format.cpp new file mode 100644 index 0000000..0a908e9 --- /dev/null +++ b/src/card_format.cpp @@ -0,0 +1,32 @@ +#include "include/card_format.h" + +string toFormattedString(Card &card) { + QString qString = QString::fromStdString(card.getCard()); + qString = qString.replace("c", "♣️"); + qString = qString.replace("d", "♦️"); + qString = qString.replace("h", "♥️"); + qString = qString.replace("s", "♠️"); + return qString.toStdString(); +} + +QString toFormattedHtml(Card &card) { + QString qString = QString::fromStdString(card.getCard()); + if(qString.contains("c")) + qString = qString.replace("c", QString::fromLocal8Bit("♣<\/span>")); + else if(qString.contains("d")) + qString = qString.replace("d", QString::fromLocal8Bit("♦<\/span>")); + else if(qString.contains("h")) + qString = qString.replace("h", QString::fromLocal8Bit("♥<\/span>")); + else if(qString.contains("s")) + qString = qString.replace("s", QString::fromLocal8Bit("♠<\/span>")); + return qString; +} + 
+QString boardCards2html(vector& cards){ + QString ret_html = ""; + for(auto one_card:cards){ + if(one_card.empty())continue; + ret_html += toFormattedHtml(one_card); + } + return ret_html; +} diff --git a/src/compairer/Dic5Compairer.cpp b/src/compairer/Dic5Compairer.cpp index 8410ea8..d6df6aa 100644 --- a/src/compairer/Dic5Compairer.cpp +++ b/src/compairer/Dic5Compairer.cpp @@ -5,11 +5,13 @@ #include "include/compairer/Dic5Compairer.h" #include -#include -#include -#include +// #include +// #include +// #include #include "time.h" +#ifndef _MSC_VER #include "unistd.h" +#endif #define SUIT_0_MASK 0x1111111111111 #define SUIT_1_MASK 0x2222222222222 @@ -51,14 +53,14 @@ void FiveCardsStrength::convert(unordered_map& strength_map) { } } bool FiveCardsStrength::load(const char* file_path) { - //ifstream file(file_path, ios::binary); - /*if (!file) { + ifstream file(file_path, ios::binary); + if (!file.is_open()) { file.close(); - return false; - }*/ + /*return false; + } QFile file(QString::fromStdString(file_path)); - if (!file.open(QIODevice::ReadOnly)){ + if (!file.open(QIODevice::ReadOnly)){*/ throw runtime_error("unable to load compairer file"); } flush_map.clear(); other_map.clear(); @@ -88,7 +90,7 @@ bool FiveCardsStrength::save(const char* file_path) { //qDebug() << "b"; //file_path = "/Users/bytedance/Desktop/card5_dic_zipped_shortdeck.bin"; ofstream file(file_path, ios::binary); - if (!file) { + if (!file.is_open()) { file.close(); return false; } @@ -136,17 +138,22 @@ bool FiveCardsStrength::check(unordered_map& strength_map) { Dic5Compairer::Dic5Compairer(string dic_dir,int lines,string dic_dir_bin):Compairer(std::move(dic_dir),lines){ if(fcs.load(dic_dir_bin.c_str())) return; - QFile infile(QString::fromStdString(this->dic_dir)); + std::ifstream infile(this->dic_dir); + if(!infile.is_open()) { + throw runtime_error("unable to load compairer file"); + } + /*QFile infile(QString::fromStdString(this->dic_dir)); if (!infile.open(QIODevice::ReadOnly)){ throw 
runtime_error("unable to load compairer file"); } QTextStream in(&infile); - //progressbar bar(lines / 1000); + //progressbar bar(lines / 1000);*/ + string line; int i = 0; - //while (std::getline(infile, line)) - while (!in.atEnd()) + while (std::getline(infile, line)) + // while (!in.atEnd()) { - string line = in.readLine().toStdString(); + // string line = in.readLine().toStdString(); vector linesp = string_split(line,','); if(linesp.size() != 2){ throw runtime_error(tfm::format("linesp not correct: %s",line)); diff --git a/src/console.cpp b/src/console.cpp index d83e852..a858ef0 100644 --- a/src/console.cpp +++ b/src/console.cpp @@ -4,31 +4,19 @@ #include "include/tools/CommandLineTool.h" #include "include/tools/argparse.hpp" -int main_backup(int argc,const char **argv) { +int main(int argc,const char **argv) { ArgumentParser parser; parser.addArgument("-i", "--input_file", 1, true); parser.addArgument("-r", "--resource_dir", 1, true); parser.addArgument("-m", "--mode", 1, true); + parser.addArgument("-l", "--log", 1, true); parser.parse(argc, argv); string input_file = parser.retrieve("input_file"); string resource_dir = parser.retrieve("resource_dir"); - if(resource_dir.empty()){ - resource_dir = "./resources"; - } string mode = parser.retrieve("mode"); - if(mode.empty()){mode = "holdem";} - if(mode != "holdem" && mode != "shortdeck") - throw runtime_error(tfm::format("mode %s error, not in ['holdem','shortdeck']",mode)); - - if(input_file.empty()) { - CommandLineTool clt = CommandLineTool(mode,resource_dir); - clt.startWorking(); - }else{ - cout << "EXEC FROM FILE" << endl; - CommandLineTool clt = CommandLineTool(mode,resource_dir); - clt.execFromFile(input_file); - } + string log_file = parser.retrieve("log"); + return cmd_api(input_file, resource_dir, mode, log_file); } diff --git a/src/library.cpp b/src/library.cpp index 03f2f10..1d16fb4 100644 --- a/src/library.cpp +++ b/src/library.cpp @@ -6,7 +6,7 @@ -vector string_split(string strin,char split){ 
+vector string_split(string &strin, char split){ vector retval; stringstream ss(strin); string token; diff --git a/src/nodes/GameActions.cpp b/src/nodes/GameActions.cpp index 920a9be..c361bfd 100644 --- a/src/nodes/GameActions.cpp +++ b/src/nodes/GameActions.cpp @@ -44,6 +44,8 @@ string GameActions::toString() { if(this->amount == -1) { return this->pokerActionToString(this->action); }else{ - return this->pokerActionToString(this->action) + " " + to_string(amount); + ostringstream oss; + oss << amount; + return this->pokerActionToString(this->action) + " " + oss.str(); } } diff --git a/src/runtime/PokerSolver.cpp b/src/runtime/PokerSolver.cpp index bff1619..2fe9ecd 100644 --- a/src/runtime/PokerSolver.cpp +++ b/src/runtime/PokerSolver.cpp @@ -3,12 +3,35 @@ // #include "include/runtime/PokerSolver.h" - -PokerSolver::PokerSolver() { - +#ifdef USE_CUDA +#include "include/solver/cuda_cfr.h" +#endif + +PokerSolver::PokerSolver(PokerMode mode, string &resource_dir) { + string suits = "c,d,h,s"; + string ranks; + string compairer_file, compairer_file_bin; + int lines; + if(mode == PokerMode::HOLDEM){ + ranks = "2,3,4,5,6,7,8,9,T,J,Q,K,A"; + compairer_file = resource_dir + "/compairer/card5_dic_sorted.txt"; + compairer_file_bin = resource_dir + "/compairer/card5_dic_zipped.bin"; + lines = 2598961; + }else if(mode == PokerMode::SHORTDECK){ + ranks = "6,7,8,9,T,J,Q,K,A"; + compairer_file = resource_dir + "/compairer/card5_dic_sorted_shortdeck.txt"; + compairer_file_bin = resource_dir + "/compairer/card5_dic_zipped_shortdeck.bin"; + lines = 376993; + }else{ + throw runtime_error(tfm::format("mode not recognized : %d", static_cast<int>(mode))); /* the format string needs a conversion for its argument, otherwise tinyformat reports a specifier/argument mismatch instead of building this message */ + } + init(ranks, suits, compairer_file, lines, compairer_file_bin); } -PokerSolver::PokerSolver(string ranks, string suits, string compairer_file,int compairer_file_lines, string compairer_file_bin) { +PokerSolver::PokerSolver(string &ranks, string &suits, string &compairer_file, int compairer_file_lines, string &compairer_file_bin) { + init(ranks, suits, 
compairer_file, compairer_file_lines, compairer_file_bin); +} +void PokerSolver::init(string &ranks, string &suits, string &compairer_file, int compairer_file_lines, string &compairer_file_bin) { vector ranks_vector = string_split(ranks,','); vector suits_vector = string_split(suits,','); this->deck = Deck(ranks_vector,suits_vector); @@ -67,29 +90,31 @@ void PokerSolver::stop(){ } } -long long PokerSolver::estimate_tree_memory(QString range1,QString range2,QString board){ +long long PokerSolver::estimate_tree_memory(string &p1_range, string &p2_range, string &board){ if(this->game_tree == nullptr){ - qDebug().noquote() << QObject::tr("Please buld tree first."); + // qDebug().noquote() << QObject::tr("Please buld tree first."); + logger->log("Please buld tree first."); return 0; } else{ - string player1RangeStr = range1.toStdString(); - string player2RangeStr = range2.toStdString(); - - vector board_str_arr = string_split(board.toStdString(),','); + vector board_str_arr = string_split(board,','); vector initialBoard; for(string one_board_str:board_str_arr){ initialBoard.push_back(Card::strCard2int(one_board_str)); } - vector range1 = PrivateRangeConverter::rangeStr2Cards(player1RangeStr,initialBoard); - vector range2 = PrivateRangeConverter::rangeStr2Cards(player2RangeStr,initialBoard); + vector range1 = PrivateRangeConverter::rangeStr2Cards(p1_range,initialBoard); + vector range2 = PrivateRangeConverter::rangeStr2Cards(p2_range,initialBoard); return this->game_tree->estimate_tree_memory(this->deck.getCards().size() - initialBoard.size(),range1.size(),range2.size()); } } -void PokerSolver::train(string p1_range, string p2_range, string boards, string log_file, int iteration_number, - int print_interval, string algorithm,int warmup,float accuracy,bool use_isomorphism, int use_halffloats, int threads) { +void PokerSolver::train(string &p1_range, string &p2_range, string &boards, /*string &log_file,*/ int iteration_number, + int print_interval, string &algorithm,int 
warmup,float accuracy,bool use_isomorphism, int use_halffloats, int threads, int slice_cfr) { + if(game_tree == nullptr) { + logger->log("Please buld tree first."); + return; + } string player1RangeStr = p1_range; string player2RangeStr = p2_range; @@ -106,44 +131,65 @@ void PokerSolver::train(string p1_range, string p2_range, string boards, string this->player1Range = noDuplicateRange(range1,initial_board_long); this->player2Range = noDuplicateRange(range2,initial_board_long); - string logfile_name = log_file; - this->solver = make_shared( - game_tree - , range1 - , range2 - , initialBoard - , compairer - , deck - , iteration_number - , false - , print_interval - , logfile_name - , algorithm - , Solver::MonteCarolAlg::NONE - , warmup - , accuracy - , use_isomorphism - , use_halffloats - , threads - ); - this->solver->train(); + // string logfile_name = log_file; + if(solver) solver.reset(); /* free memory */ + try { + if(slice_cfr == 1) { + solver = make_shared(game_tree, range1, range2, initialBoard, compairer, deck, iteration_number, print_interval, accuracy, threads, logger); + } + else if(slice_cfr == 2) { +#ifdef USE_CUDA + solver = make_shared(game_tree, range1, range2, initialBoard, compairer, deck, iteration_number, print_interval, accuracy, threads, logger); +#else + logger->log("please set USE_CUDA ON in CMakeLists.txt and rebuild project"); + return; +#endif + } + else { + solver = make_shared( + game_tree + , range1 + , range2 + , initialBoard + , compairer + , deck + , iteration_number + , false + , print_interval + , /*logfile_name*/logger + , algorithm + , Solver::MonteCarolAlg::NONE + , warmup + , accuracy + , use_isomorphism + , use_halffloats + , threads + ); + } + solver->train(); + } + catch(const std::exception& e) { /* report through logger, like every other message in this function, instead of stderr */ + logger->log("%s", e.what()); + } } -void PokerSolver::dump_strategy(QString dump_file,int dump_rounds) { +void PokerSolver::dump_strategy(string &dump_file, int dump_rounds) { //locale &loc=locale::global(locale(locale(),"",LC_CTYPE)); 
setlocale(LC_ALL,""); json dump_json = this->solver->dumps(false,dump_rounds); //QFile ofile( QString::fromStdString(dump_file)); ofstream fileWriter; - fileWriter.open(dump_file.toLocal8Bit()); + fileWriter.open(dump_file); if(!fileWriter.fail()){ fileWriter << dump_json; fileWriter.flush(); fileWriter.close(); - qDebug().noquote() << QObject::tr("save success"); + // qDebug().noquote() << QObject::tr("save success"); + logger->log("save success"); }else{ - qDebug().noquote() << QObject::tr("save failed, file cannot be open"); + // qDebug().noquote() << QObject::tr("save failed, file cannot be open"); + logger->log("save failed, file cannot be open"); } setlocale(LC_CTYPE, "C"); } diff --git a/src/runtime/qsolverjob.cpp b/src/runtime/qsolverjob.cpp index 27a663f..fa37936 100644 --- a/src/runtime/qsolverjob.cpp +++ b/src/runtime/qsolverjob.cpp @@ -10,9 +10,9 @@ void QSolverJob:: setContext(QSTextEdit * textEdit){ } PokerSolver* QSolverJob::get_solver(){ - if(this->mode == Mode::HOLDEM){ + if(this->mode == PokerMode::HOLDEM){ return &this->ps_holdem; - }else if(this->mode == Mode::SHORTDECK){ + }else if(this->mode == PokerMode::SHORTDECK){ return &this->ps_shortdeck; }else throw runtime_error("unknown mode in get_solver"); } @@ -35,12 +35,14 @@ void QSolverJob::run() } catch (const runtime_error& error) { - qDebug().noquote() << tr("Encountering error:");//.toStdString() << endl; - qDebug().noquote() << error.what() << "\n"; + // qDebug().noquote() << tr("Encountering error:");//.toStdString() << endl; + // qDebug().noquote() << error.what() << "\n"; + logger->log("Encountering error:\n%s\n", error.what()); } } void QSolverJob::loading(){ + /* string suits = "c,d,h,s"; string ranks; this->resource_dir = ":/resources"; @@ -63,33 +65,50 @@ void QSolverJob::loading(){ lines = 376993; this->ps_shortdeck = PokerSolver(ranks,suits,compairer_file,lines,compairer_file_bin); qDebug().noquote() << tr("Loading finished. 
Good to go.");//.toStdString() << endl; + */ + resource_dir = "./resources"; + // qDebug().noquote() << tr("Loading holdem compairing file");//.toStdString() << endl; + logger->log("Loading holdem compairing file"); + ps_holdem = PokerSolver(PokerMode::HOLDEM, resource_dir); + // qDebug().noquote() << tr("Loading shortdeck compairing file");//.toStdString() << endl; + logger->log("Loading shortdeck compairing file"); + ps_shortdeck = PokerSolver(PokerMode::SHORTDECK, resource_dir); + // qDebug().noquote() << tr("Loading finished. Good to go.");//.toStdString() << endl; + logger->log("Loading finished. Good to go."); + ps_holdem.logger = logger; + ps_shortdeck.logger = logger; } void QSolverJob::saving(){ - qDebug().noquote() << tr("Saving json file..");//.toStdString() << std::endl; - + // qDebug().noquote() << tr("Saving json file..");//.toStdString() << std::endl; + logger->log("Saving json file.."); + /* QSettings setting("TexasSolver", "Setting"); setting.beginGroup("solver"); this->dump_rounds = setting.value("dump_round").toInt(); - - qDebug().noquote() << tr("Dump round: ") << this->dump_rounds; - if(this->dump_rounds == 3){ - qDebug().noquote() << tr("This could be slow, or even blow your RAM, dump to river is not well optimized :("); + */ + // qDebug().noquote() << tr("Dump round: ") << clt->dump_rounds; + logger->log("Dump round: %d", clt->dump_rounds); + if(clt->dump_rounds == 3){ + // qDebug().noquote() << tr("This could be slow, or even blow your RAM, dump to river is not well optimized :("); + logger->log("This could be slow, or even blow your RAM, dump to river is not well optimized"); } - if(this->mode == Mode::HOLDEM){ - this->ps_holdem.dump_strategy(this->savefile,this->dump_rounds); - }else if(this->mode == Mode::SHORTDECK){ - this->ps_shortdeck.dump_strategy(this->savefile,this->dump_rounds); + if(this->mode == PokerMode::HOLDEM){ + this->ps_holdem.dump_strategy(clt->res_file, clt->dump_rounds); + }else if(this->mode == PokerMode::SHORTDECK){ + 
this->ps_shortdeck.dump_strategy(clt->res_file, clt->dump_rounds); } - qDebug().noquote() << tr("Saving done.");//.toStdString() << std::endl; + // qDebug().noquote() << tr("Saving done.");//.toStdString() << std::endl; + logger->log("Saving done."); } void QSolverJob::stop(){ - qDebug().noquote() << tr("Trying to stop solver."); - if(this->mode == Mode::HOLDEM){ + // qDebug().noquote() << tr("Trying to stop solver."); + logger->log("Trying to stop solver."); + if(this->mode == PokerMode::HOLDEM){ this->ps_holdem.stop(); - }else if(this->mode == Mode::SHORTDECK){ + }else if(this->mode == PokerMode::SHORTDECK){ this->ps_shortdeck.stop(); } } @@ -97,9 +116,12 @@ void QSolverJob::stop(){ void QSolverJob::solving(){ // TODO 为什么ui上多次求解会积累memory?哪里leak了? // TODO 为什么有时候会莫名闪退? - qDebug().noquote() << tr("Start Solving..");//.toStdString() << std::endl; - - if(this->mode == Mode::HOLDEM){ + // qDebug().noquote() << tr("Start Solving..");//.toStdString() << std::endl; + logger->log("Start Solving.."); + PokerSolver *ps = (mode == PokerMode::HOLDEM ? 
&ps_holdem : &ps_shortdeck); + clt->start_solve(ps); + /* + if(this->mode == PokerMode::HOLDEM){ this->ps_holdem.train( this->range_ip, this->range_oop, @@ -114,7 +136,7 @@ void QSolverJob::solving(){ this->use_halffloats, this->thread_number ); - }else if(this->mode == Mode::SHORTDECK){ + }else if(this->mode == PokerMode::SHORTDECK){ this->ps_shortdeck.train( this->range_ip, this->range_oop, @@ -130,25 +152,34 @@ void QSolverJob::solving(){ this->thread_number ); } - qDebug().noquote() << tr("Solving done.");//.toStdString() << std::endl; + */ + // qDebug().noquote() << tr("Solving done.");//.toStdString() << std::endl; + logger->log("Solving done."); } -long long QSolverJob::estimate_tree_memory(QString range1,QString range2,QString board){ - qDebug().noquote() << tr("Estimating tree memory..");//.toStdString() << endl; - if(this->mode == Mode::HOLDEM){ - return ps_holdem.estimate_tree_memory(range1,range2,board); - }else if(this->mode == Mode::SHORTDECK){ - return ps_shortdeck.estimate_tree_memory(range1,range2,board); +long long QSolverJob::estimate_tree_memory(string &range1, string &range2, string &board) { + // qDebug().noquote() << tr("Estimating tree memory..");//.toStdString() << endl; + logger->log("Estimating tree memory.."); + if(this->mode == PokerMode::HOLDEM){ + return ps_holdem.estimate_tree_memory(range1, range2, board); + }else if(this->mode == PokerMode::SHORTDECK){ + return ps_shortdeck.estimate_tree_memory(range1, range2, board); } return 0; } void QSolverJob::build_tree(){ - qDebug().noquote() << tr("building tree..");//.toStdString() << endl; - if(this->mode == Mode::HOLDEM){ + // qDebug().noquote() << tr("building tree..");//.toStdString() << endl; + logger->log("building tree.."); + PokerSolver *ps = (mode == PokerMode::HOLDEM ? 
&ps_holdem : &ps_shortdeck); + clt->build_tree(ps); + /* + if(this->mode == PokerMode::HOLDEM){ ps_holdem.build_game_tree(oop_commit,ip_commit,current_round,raise_limit,small_blind,big_blind,stack,*gtbs.get(),allin_threshold); - }else if(this->mode == Mode::SHORTDECK){ + }else if(this->mode == PokerMode::SHORTDECK){ ps_shortdeck.build_game_tree(oop_commit,ip_commit,current_round,raise_limit,small_blind,big_blind,stack,*gtbs.get(),allin_threshold); } - qDebug().noquote() << tr("build tree finished");//.toStdString() << endl; + */ + // qDebug().noquote() << tr("build tree finished");//.toStdString() << endl; + logger->log("build tree finished"); } diff --git a/src/solver/BestResponse.cpp b/src/solver/BestResponse.cpp index 90d174b..c5d4ac3 100644 --- a/src/solver/BestResponse.cpp +++ b/src/solver/BestResponse.cpp @@ -3,9 +3,9 @@ // #include "include/solver/BestResponse.h" -#include -#include -#include +// #include +// #include +// #include //#define DEBUG; BestResponse::BestResponse(vector> &private_combos, int player_number, @@ -40,7 +40,8 @@ float BestResponse::printExploitability(shared_ptr root, int itera if(this->reach_probs.empty()) this->reach_probs = vector> (this->player_number); - qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Iter: %s").toStdString().c_str(),iterationCount)); + // qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Iter: %s").toStdString().c_str(),iterationCount)); + logger->log("Iter: %d", iterationCount); float exploitible = 0; // 构造双方初始reach probs(按照手牌weights) for (int player_id = 0; player_id < this->player_number; player_id++) { @@ -54,10 +55,12 @@ float BestResponse::printExploitability(shared_ptr root, int itera for (int player_id = 0; player_id < this->player_number; player_id++) { float player_exploitability = getBestReponseEv(root, player_id, reach_probs, initialBoard, 0); exploitible += player_exploitability; - qDebug().noquote() << (QString::fromStdString(tfm::format(QObject::tr("player 
%s exploitability %s").toStdString().c_str(), player_id, player_exploitability))); + // qDebug().noquote() << (QString::fromStdString(tfm::format(QObject::tr("player %s exploitability %s").toStdString().c_str(), player_id, player_exploitability))); + logger->log("player %d exploitability %f", player_id, player_exploitability); } float total_exploitability = exploitible / this->player_number / initial_pot * 100; - qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Total exploitability %s precent").toStdString().c_str(), total_exploitability)); + // qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Total exploitability %s precent").toStdString().c_str(), total_exploitability)); + logger->log("Total exploitability %f precent", total_exploitability); return total_exploitability; } @@ -130,7 +133,7 @@ BestResponse::chanceBestReponse(shared_ptr node, int player,const ve vector> results(node->getCards().size()); #pragma omp parallel for - for(std::size_t card = 0;card < node->getCards().size();card ++) { + for(std::int64_t card = 0;card < node->getCards().size();card ++) { shared_ptr one_child = node->getChildren(); Card one_card = node->getCards()[card]; uint64_t card_long = Card::boardInt2long(one_card.getCardInt()); diff --git a/src/solver/PCfrSolver.cpp b/src/solver/PCfrSolver.cpp index 29395ed..a406a36 100644 --- a/src/solver/PCfrSolver.cpp +++ b/src/solver/PCfrSolver.cpp @@ -4,9 +4,9 @@ #include #include "include/solver/PCfrSolver.h" -#include -#include -#include +// #include +// #include +// #include //#define DEBUG; @@ -16,7 +16,7 @@ PCfrSolver::~PCfrSolver(){ PCfrSolver::PCfrSolver(shared_ptr tree, vector range1, vector range2, vector initial_board, shared_ptr compairer, Deck deck, int iteration_number, bool debug, - int print_interval, string logfile, string trainer, Solver::MonteCarolAlg monteCarolAlg,int warmup,float accuracy,bool use_isomorphism,int use_halffloats,int num_threads) :Solver(tree){ + int print_interval, 
/*string logfile*/Logger *logger, string trainer, Solver::MonteCarolAlg monteCarolAlg,int warmup,float accuracy,bool use_isomorphism,int use_halffloats,int num_threads):Solver(tree, logger){ this->initial_board = initial_board; this->initial_board_long = Card::boardInts2long(initial_board); this->logfile = logfile; @@ -53,7 +53,8 @@ PCfrSolver::PCfrSolver(shared_ptr tree, vector range1, v if(num_threads == -1){ num_threads = omp_get_num_procs(); } - qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Using %s threads").toStdString().c_str(),num_threads)); + // qDebug().noquote() << QString::fromStdString(tfm::format(QObject::tr("Using %s threads").toStdString().c_str(),num_threads)); + logger->log("Using %d threads", num_threads); this->num_threads = num_threads; this->distributing_task = false; omp_set_num_threads(this->num_threads); @@ -299,7 +300,7 @@ PCfrSolver::chanceUtility(int player, shared_ptr node, const vector< } #pragma omp parallel for schedule(static) - for(std::size_t valid_ind = 0;valid_ind < valid_cards.size();valid_ind++) { + for(std::int64_t valid_ind = 0;valid_ind < valid_cards.size();valid_ind++) { int card = valid_cards[valid_ind]; shared_ptr one_child = node->getChildren(); Card *one_card = const_cast(&(node->getCards()[card])); @@ -776,7 +777,7 @@ void PCfrSolver::train() { } BestResponse br = BestResponse(player_privates,this->player_number,this->pcm,this->rrm,this->deck,this->debug,this->color_iso_offset,this->split_round,this->num_threads,this->use_halffloats); - + br.logger = logger; br.printExploitability(tree->getRoot(), 0, tree->getRoot()->getPot(), initial_board_long); vector> reach_probs = this->getReachProbs(); @@ -802,9 +803,11 @@ void PCfrSolver::train() { if( (i % this->print_interval == 0 && i != 0 && i >= this->warmup) || this->nowstop) { endtime = timeSinceEpochMillisec(); long time_ms = endtime - begintime; - qDebug().noquote() << "-------------------"; + // qDebug().noquote() << "-------------------"; + 
logger->log("-------------------"); float expliotibility = br.printExploitability(tree->getRoot(), i + 1, tree->getRoot()->getPot(), initial_board_long); - qDebug().noquote() << QObject::tr("time used: ") << float(time_ms) / 1000 << QObject::tr(" second."); + // qDebug().noquote() << QObject::tr("time used: ") << float(time_ms) / 1000 << QObject::tr(" second."); + logger->log("time used: %f second.", float(time_ms) / 1000); if(!this->logfile.empty()){ json jo; jo["iteration"] = i; @@ -823,7 +826,8 @@ void PCfrSolver::train() { } } - qDebug().noquote() << QObject::tr("collecting statics"); + // qDebug().noquote() << QObject::tr("collecting statics"); + logger->log("collecting statics"); this->collecting_statics = true; for(int player_id = 0;player_id < this->player_number;player_id ++) { this->round_deal = vector{-1,-1,-1,-1}; @@ -838,8 +842,8 @@ void PCfrSolver::train() { } this->collecting_statics = false; this->statics_collected = true; - qDebug().noquote() << QObject::tr("statics collected"); - + // qDebug().noquote() << QObject::tr("statics collected"); + logger->log("statics collected"); if(!this->logfile.empty()) { fileWriter.flush(); fileWriter.close(); @@ -935,14 +939,14 @@ void PCfrSolver::reConvertJson(const shared_ptr& node,json& strate shared_ptr childerns = chanceNode->getChildren(); vector card_strs; for(Card card:cards) - card_strs.push_back(card.toString()); + card_strs.push_back(card.getCard()); json& dealcards = (*retval)["dealcards"]; for(std::size_t i = 0;i < cards.size();i ++){ vector> new_exchange_color_list(exchange_color_list); Card& one_card = const_cast(cards[i]); vector new_prefix(prefix); - new_prefix.push_back("Chance:" + one_card.toString()); + new_prefix.push_back("Chance:" + one_card.getCard()); std::size_t card = i; @@ -984,7 +988,7 @@ void PCfrSolver::reConvertJson(const shared_ptr& node,json& strate throw runtime_error("exchange color list shouldn't be exceed size 1 here"); } - string one_card_str = one_card.toString(); + string 
one_card_str = one_card.getCard(); if(exchange_color_list.size() == 1) { int rank1 = exchange_color_list[0][0]; int rank2 = exchange_color_list[0][1]; diff --git a/src/solver/Solver.cpp b/src/solver/Solver.cpp index 18424ad..5fd4677 100644 --- a/src/solver/Solver.cpp +++ b/src/solver/Solver.cpp @@ -8,7 +8,7 @@ Solver::Solver() { } -Solver::Solver(shared_ptr tree) { +Solver::Solver(shared_ptr tree, Logger *logger):logger(logger) { this->tree = tree; } diff --git a/src/solver/cuda_cfr.cu b/src/solver/cuda_cfr.cu new file mode 100644 index 0000000..dcd3a3b --- /dev/null +++ b/src/solver/cuda_cfr.cu @@ -0,0 +1,350 @@ +#include "solver/cuda_cfr.h" +#include "solver/cuda_func.h" +#include "ranges/RiverRangeManager.h" + +void cuda_error(cudaError_t error, const char *file, int line) { + if(error != cudaSuccess) { + printf("%s in %s at line %d\n", cudaGetErrorString(error), file, line); + exit(EXIT_FAILURE); + } +} + +#define CHECK_ERROR(error) (cuda_error(error, __FILE__, __LINE__)) + +template +void copy_to_device(T *dev, T *host, int n, bool print=false) { + if(!dev || !host || n <= 0) return; + size_t size = n * sizeof(T); + CHECK_ERROR(cudaMemcpy(dev, host, size, cudaMemcpyHostToDevice)); + if(!print) return; + print_data(host, n); + print_data_kernel<<<1, 1>>>(dev, n); + cudaDeviceSynchronize(); +} + +int max_malloc_len(int left, int right, int group_size = 1) { + int mid = 0, size = group_size * sizeof(float); + float *p = nullptr; + while(left < right) { + mid = (left + right + 1) >> 1;// 靠右 + if(cudaMalloc(&p, mid * size) == cudaSuccess) { + cudaFree(p); + left = mid; + } + else right = mid - 1; + } + return left; +} + +void CudaCFR::leaf_cfv(int player) { + Timer timer; + int opp = 1 - player, offset = player == P0 ? 
0 : hand_size[P0]; + int my_hand = hand_size[player], opp_hand = hand_size[opp]; + int size = node_cnt[FOLD_TYPE]; + int block = block_size(size); + clear_prob_sum(size); + fold_cfv_kernel<<>>( + player, size, dev_leaf_node, dev_prob_sum, my_hand, opp_hand, + dev_hand_card_ptr[opp], dev_hand_hash_ptr[opp], dev_same_hand_idx+offset + ); + cudaDeviceSynchronize(); + // printf("fold_cfv:%zd ms\n", timer.ms(true)); + + size = node_cnt[SHOWDOWN_TYPE]; + block = block_size(size); + clear_prob_sum(size); + sd_cfv_kernel<<>>( + player, size, dev_leaf_node+sd_offset, dev_prob_sum, my_hand, opp_hand, + dev_hand_card_ptr[player], dev_hand_card_ptr[opp], n_card + ); + cudaDeviceSynchronize(); + // printf("sd_cfv:%zd ms\n", timer.ms()); +} + +CudaCFR::~CudaCFR() { + if(dev_root_cfv) CHECK_ERROR(cudaFree(dev_root_cfv)); + if(dev_hand_card) CHECK_ERROR(cudaFree(dev_hand_card)); + if(dev_hand_hash) CHECK_ERROR(cudaFree(dev_hand_hash)); + if(dev_nodes) CHECK_ERROR(cudaFree(dev_nodes)); + if(dev_leaf_node) CHECK_ERROR(cudaFree(dev_leaf_node)); + for(float *p : dev_data) { + if(p) CHECK_ERROR(cudaFree(p)); + } + if(dev_prob_sum) CHECK_ERROR(cudaFree(dev_prob_sum)); + for(int *p : dev_strength) { + if(p) CHECK_ERROR(cudaFree(p)); + } +} + +void CudaCFR::set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset) { + if(player == -1) player = node.player;// 向上连接同玩家节点 + int p_idx = node.parent_p0_idx, act_idx = node.parent_p0_act;// 向上连接P0 + if(player != P0) {// 向上连接P1 + p_idx = node.parent_p1_idx; + act_idx = node.parent_p1_act; + } + if(p_idx == -1) { + cfv = root_cfv_ptr[player]; + offset = root_prob_ptr[player] - root_cfv_ptr[player]; + } + else { + if(player != dfs_node[p_idx].player) throw runtime_error("player mismatch"); + cfv = dev_data[dfs_idx_map[p_idx]] + cfv_offset(hand_size[player], act_idx); + offset = reach_prob_to_cfv(dfs_node[p_idx].n_act, hand_size[player]); + } +} + +size_t CudaCFR::init_player_node() { + size_t total = 0, size = 0, node_size = 
n_player_node * sizeof(Node); + vector cpu_node(n_player_node);// 与cuda内存对应 + CHECK_ERROR(cudaMalloc(&dev_nodes, node_size)); + total += node_size; + dev_data = vector(n_player_node, nullptr); + dfs_idx_map = vector(dfs_idx, -1); + slice_offset = vector>(N_PLAYER); + int mem_idx = 0; + for(int i = 0; i < N_PLAYER; i++) {// 枚举player + for(vector &nodes : slice[i]) {// 枚举slice + slice_offset[i].push_back(mem_idx); + for(int idx : nodes) {// 枚举node + DFSNode &node = dfs_node[idx]; + Node &target = cpu_node[mem_idx];// cpu存储位置 + target.n_act = node.n_act; + set_cfv_and_offset(node, -1, target.parent_cfv, target.parent_offset); + size = get_size(node.n_act, hand_size[node.player]) * sizeof(float); + CHECK_ERROR(cudaMalloc(&target.data, size)); + if(target.data == nullptr) throw runtime_error("malloc error"); + total += size; + dev_data[mem_idx] = target.data; + dfs_idx_map[idx] = mem_idx++; + } + } + slice_offset[i].push_back(mem_idx); + } + CHECK_ERROR(cudaMemcpy(dev_nodes, cpu_node.data(), node_size, cudaMemcpyHostToDevice)); + return total; +} + +size_t CudaCFR::init_leaf_node() { + size_t node_size = n_leaf_node * sizeof(CudaLeafNode); + vector cpu_node(n_leaf_node);// 与cuda内存对应 + CHECK_ERROR(cudaMalloc(&dev_leaf_node, node_size)); + int mem_idx = 0; + for(int t = 0; t < N_LEAF_TYPE; t++) { + for(int i = 0; i < leaf_node_dfs[t].size(); i++) { + DFSNode &node = dfs_node[leaf_node_dfs[t][i]]; + CudaLeafNode &target = cpu_node[mem_idx++];// cpu存储位置 + target.val = ev[t][i]; + target.offset_prob_sum = i * n_card; + set_cfv_and_offset(node, P0, target.data_p0, target.offset_p0); + set_cfv_and_offset(node, P1, target.data_p1, target.offset_p1); + int j = decode_idx0(node.info), k = decode_idx1(node.info); + size_t info = init_board; + if(t == FOLD_TYPE) { + if(j != -1) info |= 1LL << poss_card[j]; + if(k != -1) info |= 1LL << poss_card[k]; + target.info = (int *)info; + } + else { + if(j == -1) info = 0; + else if(k == -1) info = j; + else info = tril_idx(j, k); + 
target.info = dev_strength[info]; + } + } + } + CHECK_ERROR(cudaMemcpy(dev_leaf_node, cpu_node.data(), node_size, cudaMemcpyHostToDevice)); + sd_offset = leaf_node_dfs[FOLD_TYPE].size(); + ev.clear(); + return node_size; +} + +size_t CudaCFR::init_memory() { + size_t total = 0; + int n = root_prob.size(); + root_cfv = vector(n, 0); + size_t size = ((n << 1)) * sizeof(float);// cfv + prob + CHECK_ERROR(cudaMalloc(&dev_root_cfv, size)); + total += size; + root_cfv_ptr[P0] = dev_root_cfv; + root_cfv_ptr[P1] = dev_root_cfv + hand_size[P0]; + root_prob_ptr[P0] = root_cfv_ptr[P0] + n; + root_prob_ptr[P1] = root_prob_ptr[P0] + hand_size[P0]; + clear_root_cfv(); + copy_to_device(root_prob_ptr[P0], root_prob.data(), n); + + vector temp_hand_card = hand_card; + vector temp_hand_hash = hand_hash; + // [P0,P1,P0] + temp_hand_card.insert(temp_hand_card.end(), hand_card.begin(), hand_card.begin()+(hand_size[P0]<<1)); + temp_hand_hash.insert(temp_hand_hash.end(), hand_hash.begin(), hand_hash.begin()+hand_size[P0]); + n = temp_hand_card.size(); + size = (n + same_hand_idx.size()) * sizeof(int);// [P0,P1,P0] + [P0,P1] + CHECK_ERROR(cudaMalloc(&dev_hand_card, size)); + total += size; + copy_to_device(dev_hand_card, temp_hand_card.data(), n); + dev_same_hand_idx = dev_hand_card + n; + copy_to_device(dev_same_hand_idx, same_hand_idx.data(), same_hand_idx.size()); + + n = temp_hand_hash.size(); + size = n * sizeof(size_t); + CHECK_ERROR(cudaMalloc(&dev_hand_hash, size)); + total += size; + copy_to_device(dev_hand_hash, temp_hand_hash.data(), n); + dev_hand_card_ptr[P0] = dev_hand_card; + dev_hand_card_ptr[P1] = dev_hand_card + (hand_size[P0]<<1); + dev_hand_hash_ptr[P0] = dev_hand_hash; + dev_hand_hash_ptr[P1] = dev_hand_hash + hand_size[P0]; + + total += init_player_node(); + total += init_strength_table(); + total += init_leaf_node(); + + // FOLD_TYPE,SHOWDOWN_TYPE,共用dev_prob_sum + int len = max(node_cnt[FOLD_TYPE], node_cnt[SHOWDOWN_TYPE]); + size = len * n_card * sizeof(float); + 
CHECK_ERROR(cudaMalloc(&dev_prob_sum, size)); + total += size; + return total; +} + +size_t CudaCFR::init_strength_table() { + SliceCFR::init_strength_table(); + int n = strength.size(); + size_t total = 0, size = 0; + dev_strength = vector(n, nullptr); + for(int i = 0; i < n; i++) { + const RiverCombs *p0_comb = strength[i][P0].data, *p1_comb = strength[i][P1].data; + int p0_size = strength[i][P0].size, p1_size = strength[i][P1].size, d = 0; + vector data(2+((p0_size+p1_size)<<1)); + data[d++] = 2 + (p0_size<<1); + data[d++] = data.size(); + for(int j = 0; j < p0_size; j++) { + data[d++] = p0_comb[j].rank; + data[d++] = p0_comb[j].reach_prob_index; + } + for(int j = 0; j < p1_size; j++) { + data[d++] = p1_comb[j].rank; + data[d++] = p1_comb[j].reach_prob_index; + } + size = data.size() * sizeof(int); + CHECK_ERROR(cudaMalloc(&dev_strength[i], size)); + total += size; + copy_to_device(dev_strength[i], data.data(), data.size()); + } + strength.clear(); + rrm.clear(); + return total; +} + +size_t CudaCFR::estimate_tree_size() { + for(int i = 0; i < N_TYPE; i++) node_cnt[i] = 0; + if(tree == nullptr) return 0; + size_t size = _estimate_tree_size(tree->getRoot()); + n_leaf_node = node_cnt[FOLD_TYPE] + node_cnt[SHOWDOWN_TYPE]; + n_player_node = node_cnt[N_LEAF_TYPE+P0] + node_cnt[N_LEAF_TYPE+P1]; + size *= sizeof(float); + size += n_leaf_node * sizeof(CudaLeafNode); + size += n_player_node * sizeof(Node); + size += max(node_cnt[FOLD_TYPE], node_cnt[SHOWDOWN_TYPE]) * n_card * sizeof(float); + return size; +} + +void CudaCFR::_reach_prob(int player, bool avg_strategy) { + vector& offset = slice_offset[player]; + int n = offset.size() - 1, size = 0, block = 0, n_hand = hand_size[player]; + for(int i = 0; i < n; i++) { + size = offset[i+1] - offset[i]; + block = block_size(size); + if(avg_strategy) reach_prob_avg_kernel<<>>(dev_nodes+offset[i], size, n_hand); + else reach_prob_kernel<<>>(dev_nodes+offset[i], size, n_hand); + cudaDeviceSynchronize(); + } +} + +void 
CudaCFR::_rm(int player, bool avg_strategy) { + int size = node_cnt[N_LEAF_TYPE + player]; + int block = block_size(size); + Node *node = dev_nodes + slice_offset[player][0]; + if(avg_strategy) rm_avg_kernel<<>>(node, size, hand_size[player]); + else rm_kernel<<>>(node, size, hand_size[player]); + cudaDeviceSynchronize(); +} + +void CudaCFR::clear_data(int player) { + int size = node_cnt[N_LEAF_TYPE + player]; + int block = block_size(size); + clear_data_kernel<<>>(dev_nodes+slice_offset[player][0], size, hand_size[player]); + cudaDeviceSynchronize(); +} + +void CudaCFR::clear_prob_sum(int len) { + CHECK_ERROR(cudaMemset(dev_prob_sum, 0, len * n_card * sizeof(float))); + cudaDeviceSynchronize(); +} + +void CudaCFR::clear_root_cfv() { + CHECK_ERROR(cudaMemset(dev_root_cfv, 0, root_cfv.size() * sizeof(float))); + cudaDeviceSynchronize(); +} + +void CudaCFR::step(int iter, int player, int task) { + Timer timer; + int opp = 1 - player, my_hand = hand_size[player], size = 0, block = 0; + _reach_prob(opp, task != CFR_TASK); + size_t t1 = timer.ms(true); + + leaf_cfv(player); + size_t t2 = timer.ms(true); + + if(task == CFR_TASK) { + size = n_player_node; + block = block_size(size); + discount_data_kernel<<>>(dev_nodes, size, my_hand, pos_coef, neg_coef, coef); + cudaDeviceSynchronize(); + } + size_t t3 = timer.ms(true); + vector& offset = slice_offset[player]; + for(int i = offset.size()-2; i >= 0; i--) { + size = offset[i+1] - offset[i]; + block = block_size(size); + if(task == EXP_TASK) best_cfv_kernel<<>>(dev_nodes+offset[i], size, my_hand); + else cfv_kernel<<>>(dev_nodes+offset[i], size, my_hand); + cudaDeviceSynchronize(); + } + size_t t4 = timer.ms(); + printf("%zd\t%zd\t%zd\t%zd\n", t1, t2, t3, t4); +} + +void CudaCFR::post_process() { + int n = root_cfv.size(); + CHECK_ERROR(cudaMemcpy(root_cfv.data(), dev_root_cfv, n * sizeof(float), cudaMemcpyDeviceToHost)); + // print_data(root_cfv.data(), n); + // print_data_kernel<<<1, 1>>>(dev_root_cfv, n); + // 
cudaDeviceSynchronize(); +} + +vector> CudaCFR::get_avg_strategy(int idx) { + DFSNode &node = dfs_node[idx]; + int n_hand = hand_size[node.player], n_act = node.n_act; + int size = n_act * n_hand, i = 0, h = 0, j = 0; + float *dev = dev_data[dfs_idx_map[idx]] + (size << 1), sum = 0, uni = 1.0 / n_act; + vector strategy_sum(size);// [n_act,n_hand] + CHECK_ERROR(cudaMemcpy(strategy_sum.data(), dev, size * sizeof(float), cudaMemcpyDeviceToHost)); + vector> strategy(n_hand, vector(n_act));// [n_hand,n_act] + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += strategy_sum[i]; + if(sum == 0) { + for(j = 0; j < n_act; j++) strategy[h][j] = uni; + } + else { + for(j = 0, i = h; j < n_act; j++, i += n_hand) strategy[h][j] = strategy_sum[i] / sum; + } + } + return strategy; +} +vector> CudaCFR::get_ev(int idx) { + return {}; +} +void CudaCFR::cfv_to_ev() {} diff --git a/src/solver/cuda_func.cu b/src/solver/cuda_func.cu new file mode 100644 index 0000000..5e7a52c --- /dev/null +++ b/src/solver/cuda_func.cu @@ -0,0 +1,277 @@ +#include "solver/cuda_cfr.h" +#include "solver/cuda_func.h" +#include "device_launch_parameters.h" + +__host__ __device__ void print_data(int *arr, int n) { + if(arr != nullptr && n > 0) { + printf("%d", arr[0]); + for(int i = 1; i < n; i++) printf(",%d", arr[i]); + } + printf("\n"); +} +__host__ __device__ void print_data(size_t *arr, int n) { + if(arr != nullptr && n > 0) { + printf("%llx", arr[0]); + for(int i = 1; i < n; i++) printf(",%llx", arr[i]); + } + printf("\n"); +} +__host__ __device__ void print_data(float *arr, int n) { + if(arr != nullptr && n > 0) { + printf("%.2f", arr[0]); + for(int i = 1; i < n; i++) printf(",%.2f", arr[i]); + } + printf("\n"); +} +__global__ void print_data_kernel(int *arr, int n) { + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i == 0) print_data(arr, n); +} +__global__ void print_data_kernel(size_t *arr, int n) { + unsigned int i = blockIdx.x * blockDim.x + 
threadIdx.x; + if(i == 0) print_data(arr, n); +} +__global__ void print_data_kernel(float *arr, int n) { + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i == 0) print_data(arr, n); +} + +__global__ void clear_data_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + size = get_size(node->n_act, n_hand); + float *data = node->data; + for(i = 0; i < size; i++) data[i] = 0; +} + +// 不同节点之间独立 +__global__ void rm_avg_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + size = node->n_act * n_hand; + int h = 0, sum_offset = size << 1; + float *data = node->data + (size << 1);// strategy_sum + float sum = 0; + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += data[i]; + data[sum_offset+h] = sum; + } +} +__global__ void rm_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + size = node->n_act * n_hand; + int h = 0, sum_offset = size * 3; + float *data = node->data + size;// regret_sum + float sum = 0; + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += max(0.0f, data[i]); + data[sum_offset+h] = sum; + } +} + +// 上层slice传递到本层slice +__global__ void reach_prob_avg_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + int n_act = node->n_act; + size = n_act * n_hand; + int h = 0, sum_offset = size << 1; + float *data = node->data + (size << 1);// strategy_sum + float *parent_prob = node->parent_cfv + node->parent_offset, temp = 0; + for(h = 0; h < n_hand; h++) { + if(data[sum_offset+h] == 0) {// 1/n_act + temp = parent_prob[h] / n_act; + for(i = h; i < size; i += n_hand) data[size+i] = temp; + } + else { + temp = parent_prob[h] / data[sum_offset+h]; + for(i = h; i < size; i += n_hand) 
data[size+i] = temp * data[i]; + } + } +} +__global__ void reach_prob_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + int n_act = node->n_act; + size = n_act * n_hand; + int h = 0, rp_offset = size << 1, sum_offset = rp_offset + size; + float *data = node->data + size;// regret_sum + float *parent_prob = node->parent_cfv + node->parent_offset, temp = 0; + for(h = 0; h < n_hand; h++) { + if(data[sum_offset+h] == 0) {// 1/n_act + temp = parent_prob[h] / n_act; + for(i = h; i < size; i += n_hand) data[rp_offset+i] = temp; + } + else { + temp = parent_prob[h] / data[sum_offset+h]; + for(i = h; i < size; i += n_hand) data[rp_offset+i] = temp * max(0.0f, data[i]); + } + } +} + +// 叶子节点向上层slice聚合,调用前需要清零上层slice的cfv +// same_hand_idx:player same_hand_idx +// hand_hash,hand_card:init opp [P0,P1,P0] +__global__ void fold_cfv_kernel(int player, int size, CudaLeafNode *node, float *opp_prob_sum, int my_hand, int opp_hand, int *hand_card, size_t *hand_hash, int *same_hand_idx) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + opp_prob_sum += node->offset_prob_sum; + size_t board = (size_t)node->info; + float *cfv = nullptr, *opp_reach = nullptr, val = node->val; + float prob_sum = 0, temp = 0; + if(player == P0) { + cfv = node->data_p0, opp_reach = node->data_p1 + node->offset_p1; + } + else { + cfv = node->data_p1, opp_reach = node->data_p0 + node->offset_p0; + val = -val; + } + for(i = 0; i < opp_hand; i++) { + if(hand_hash[i] & board) continue;// 对方手牌与公共牌冲突 + temp = opp_reach[i]; + opp_prob_sum[hand_card[i]] += temp;// card1 + opp_prob_sum[hand_card[i+opp_hand]] += temp;// card2 + prob_sum += temp; + } + hand_hash += opp_hand;// ptr of player + hand_card += (opp_hand << 1); + for(i = 0; i < my_hand; i++) { + if(hand_hash[i] & board) { + // cfv[i] = 0;// 与公共牌冲突,cfv为0 + continue; + } + temp = same_hand_idx[i] != -1 ? 
opp_reach[same_hand_idx[i]] : 0;// 重复计算的部分 + temp = (prob_sum - opp_prob_sum[hand_card[i]] - opp_prob_sum[hand_card[i+my_hand]] + temp) * val; + atomicAdd(cfv+i, temp); + } +} + +// showdown +__global__ void sd_cfv_kernel(int player, int size, CudaLeafNode *node, float *opp_prob_sum, int my_hand, int opp_hand, int *my_card, int *opp_card, int n_card) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return;// 总任务数 + node += i; + opp_prob_sum += node->offset_prob_sum; + float *cfv = nullptr, *opp_reach = nullptr; + float prob_sum = 0, temp = 0; + int j = 0, size_j = 0, h = 0, s = 0, *strength_data = node->info; + // strength_data:2+size0,2+size0+size1,sorted_data + // i,size for player + // j,size_j for opp + if(player == P0) { + i = 2, size_j = strength_data[1]; + size = j = strength_data[0]; + cfv = node->data_p0, opp_reach = node->data_p1 + node->offset_p1; + } + else { + j = 2, size = strength_data[1]; + size_j = i = strength_data[0]; + cfv = node->data_p1, opp_reach = node->data_p0 + node->offset_p0; + } + // strength_data += 2; + for(; i < size; i += 2) {// strength值变小,己方手牌变强 + s = strength_data[i]; + for(; j < size_j && strength_data[j] > s; j += 2) {// (胜过对方条件下)找到对方的最强手牌 + h = strength_data[j+1]; + temp = opp_reach[h]; + opp_prob_sum[opp_card[h]] += temp;// card1 + opp_prob_sum[opp_card[h+opp_hand]] += temp;// card2 + prob_sum += temp; + } + h = strength_data[i+1]; + temp = (prob_sum - opp_prob_sum[my_card[h]] - opp_prob_sum[my_card[h+my_hand]]) * node->val; + atomicAdd(cfv+h, temp); + } + prob_sum = 0; + for(h = 0; h < n_card; h++) opp_prob_sum[h] = 0; + i -= 2, j -= 2; + if(player == P0) { + size_j = size; + size = 2; + } + else { + size = size_j; + size_j = 2; + } + for(; i >= size; i -= 2) {// strength值变大,己方手牌变弱 + s = strength_data[i]; + for(; j >= size_j && strength_data[j] < s; j -= 2) {// (败给对方条件下)找到对方的最弱手牌 + h = strength_data[j+1]; + temp = opp_reach[h]; + opp_prob_sum[opp_card[h]] += temp;// card1 + 
opp_prob_sum[opp_card[h+opp_hand]] += temp;// card2 + prob_sum += temp; + } + h = strength_data[i+1]; + temp = (opp_prob_sum[my_card[h]] + opp_prob_sum[my_card[h+my_hand]] - prob_sum) * node->val; + atomicAdd(cfv+h, temp); + } +} + +// 本层slice向上层slice聚合,上层cfv需要先清零 +// 子节点cfv中选最大值 +__global__ void best_cfv_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + size = node->n_act * n_hand; + int h = 0; + float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; + for(h = 0; h < n_hand; h++) { + val = cfv[h];// 第一个 + for(i = h+n_hand; i < size; i += n_hand) val = max(val, cfv[i]); + atomicAdd(parent_cfv+h, val); + } + for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv +} +// 子节点cfv加权求和 +__global__ void cfv_kernel(Node *node, int size, int n_hand) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + int n_act = node->n_act; + size = n_act * n_hand; + int h = 0, sum_offset = size << 2; + float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; + float *regret_sum = cfv + size; + for(h = 0; h < n_hand; h++) { + val = 0; + if(cfv[sum_offset+h] == 0) { + for(i = h; i < size; i += n_hand) val += cfv[i]; + val /= n_act;// uniform strategy + } + else { + for(i = h; i < size; i += n_hand) { + val += cfv[i] * max(0.0f, regret_sum[i]); + } + val /= cfv[sum_offset+h]; + } + atomicAdd(parent_cfv+h, val); + for(i = h; i < size; i += n_hand) regret_sum[i] += cfv[i] - val;// 更新regret_sum + val = 0; + for(i = h; i < size; i += n_hand) val += max(0.0f, regret_sum[i]); + cfv[sum_offset+h] = val;// 求和 + } + for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv +} + +__global__ void discount_data_kernel(Node *node, int size, int n_hand, float pos_coef, float neg_coef, float coef) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if(i >= size) return; + node += i; + size = node->n_act * n_hand; + float *regret_sum = node->data + size, *strategy_sum = regret_sum + size; + 
for(i = 0; i < size; i++) { + regret_sum[i] *= regret_sum[i] > 0 ? pos_coef : neg_coef; + strategy_sum[i] = strategy_sum[i] * coef + strategy_sum[size+i]; + } +} diff --git a/src/solver/slice_cfr.cpp b/src/solver/slice_cfr.cpp new file mode 100644 index 0000000..99e7a1d --- /dev/null +++ b/src/solver/slice_cfr.cpp @@ -0,0 +1,1100 @@ +#include +#include "include/solver/slice_cfr.h" +#include "include/ranges/RiverRangeManager.h" + +using std::memory_order_relaxed; +using std::atomic_ref; + +void print_array(int *arr, int n) { + if(arr != nullptr && n > 0) { + printf("%d", arr[0]); + for(int i = 1; i < n; i++) printf(",%d", arr[i]); + } + printf("\n"); +} + +void test_parallel_for(int n_thread, int n = 100) { + vector cnt(n_thread); + #pragma omp parallel for + for(int i = 0; i < n; i++) { + cnt[omp_get_thread_num()]++; + } + print_array(cnt.data(), n_thread); +} + +inline bool cards_valid(size_t hash1, size_t hash2) { + return (hash1 & hash2) == 0; +} + +typedef void (*node_func)(Node *, int); + +void rm_avg(Node *node, int n_hand) { + int size = node->n_act * n_hand; + int i = 0, h = 0, sum_offset = size << 1; + float *data = node->data + (size << 1);// strategy_sum + float sum = 0; + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += data[i]; + data[sum_offset+h] = sum; + } +} +void rm(Node *node, int n_hand) { + int size = node->n_act * n_hand; + int i = 0, h = 0, sum_offset = size * 3; + float *data = node->data + size;// regret_sum + float sum = 0; + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += max(0.0f, data[i]); + data[sum_offset+h] = sum; + } +} +void reach_prob_avg(Node *node, int n_hand) { + int n_act = node->n_act, size = n_act * n_hand; + int i = 0, h = 0, sum_offset = size << 1; + float *data = node->data + (size << 1);// strategy_sum + float *parent_prob = node->parent_cfv + node->parent_offset, temp = 0; + for(h = 0; h < n_hand; h++) { + if(data[sum_offset+h] == 0) {// 1/n_act + 
temp = parent_prob[h] / n_act; + for(i = h; i < size; i += n_hand) data[size+i] = temp; + } + else { + temp = parent_prob[h] / data[sum_offset+h]; + for(i = h; i < size; i += n_hand) data[size+i] = temp * data[i]; + } + } +} +void reach_prob(Node *node, int n_hand) { + int n_act = node->n_act, size = n_act * n_hand; + int i = 0, h = 0, rp_offset = size << 1, sum_offset = rp_offset + size; + float *data = node->data + size;// regret_sum + float *parent_prob = node->parent_cfv + node->parent_offset, temp = 0; + for(h = 0; h < n_hand; h++) { + if(data[sum_offset+h] == 0) {// 1/n_act + temp = parent_prob[h] / n_act; + for(i = h; i < size; i += n_hand) data[rp_offset+i] = temp; + } + else { + temp = parent_prob[h] / data[sum_offset+h]; + for(i = h; i < size; i += n_hand) data[rp_offset+i] = temp * max(0.0f, data[i]); + } + } +} +// 子节点cfv取最大值 +void best_cfv_up(Node *node, int n_hand) { + int size = node->n_act * n_hand; + int i = 0, h = 0; + float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; + // mutex *mtx = node->mtx; + for(h = 0; h < n_hand; h++) { + val = cfv[h];// 第一个 + for(i = h+n_hand; i < size; i += n_hand) val = max(val, cfv[i]); + // mtx->lock(); + // parent_cfv[h] += val;// 需要加锁 + // mtx->unlock(); + atomic_ref(parent_cfv[h]).fetch_add(val, memory_order_relaxed); + } + for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv +} +// 子节点cfv加权求和 +void cfv_up(Node *node, int n_hand) { + int n_act = node->n_act, size = n_act * n_hand; + int i = 0, h = 0, sum_offset = size << 2; + float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; + float *regret_sum = cfv + size; + // mutex *mtx = node->mtx; + for(h = 0; h < n_hand; h++) { + val = 0; + if(cfv[sum_offset+h] == 0) { + for(i = h; i < size; i += n_hand) val += cfv[i]; + val /= n_act;// uniform strategy + } + else { + for(i = h; i < size; i += n_hand) { + val += cfv[i] * max(0.0f, regret_sum[i]); + } + val /= cfv[sum_offset+h]; + } + // cfv[sum_offset+h] = val; + // mtx->lock(); + // parent_cfv[h] += 
val;// 需要加锁 + // mtx->unlock(); + atomic_ref(parent_cfv[h]).fetch_add(val, memory_order_relaxed); + for(i = h; i < size; i += n_hand) regret_sum[i] += cfv[i] - val;// 更新regret_sum + val = 0; + for(i = h; i < size; i += n_hand) val += max(0.0f, regret_sum[i]); + cfv[sum_offset+h] = val;// 求和 + } + for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv +} +// 只计算cfv +void cfv_up_avg(Node *node, int n_hand) { + int n_act = node->n_act, size = n_act * n_hand; + int i = 0, h = 0, sum_offset = size << 2; + float *parent_cfv = node->parent_cfv, *cfv = node->data, val = 0; + float *strategy_sum = cfv + (size << 1); + // mutex *mtx = node->mtx; + for(h = 0; h < n_hand; h++) { + val = 0; + if(cfv[sum_offset+h] == 0) { + for(i = h; i < size; i += n_hand) val += cfv[i]; + val /= n_act;// uniform strategy + } + else { + for(i = h; i < size; i += n_hand) { + val += cfv[i] * strategy_sum[i]; + } + val /= cfv[sum_offset+h]; + } + // cfv[sum_offset+h] = val; + // mtx->lock(); + // parent_cfv[h] += val;// 需要加锁 + // mtx->unlock(); + atomic_ref(parent_cfv[h]).fetch_add(val, memory_order_relaxed); + // for(i = h; i < size; i += n_hand) regret_sum[i] += cfv[i] - val;// 更新regret_sum + // val = 0; + // for(i = h; i < size; i += n_hand) val += max(0.0f, regret_sum[i]); + // cfv[sum_offset+h] = val;// 求和 + } + // for(i = 0; i < size; i++) cfv[i] = 0;// 清零cfv +} +// 在cfv_up前执行 +void discount_data(Node *node, int n_hand, float pos_coef, float neg_coef, float coef) { + int size = node->n_act * n_hand, i = 0; + float *regret_sum = node->data + size, *strategy_sum = regret_sum + size; + for(i = 0; i < size; i++) { + regret_sum[i] *= regret_sum[i] > 0 ? 
pos_coef : neg_coef; + strategy_sum[i] = strategy_sum[i] * coef + strategy_sum[size+i]; + } +} + +void SliceCFR::cfv_to_ev() { + for(int i = 0; i < N_PLAYER; i++) { + vector& offset = slice_offset[i]; + #pragma omp parallel for + for(int j = offset[0]; j < offset.back(); j++) { + cfv_to_ev(player_node_ptr+j, i); + } + } +} +void SliceCFR::cfv_to_ev(Node *node, int player) { + float *opp_reach = node->opp_prob, *cfv = node->data; + size_t board = node->board; + vector opp_prob_sum(n_card, 0); + float prob_sum = 0, temp = 0; + get_prob_sum(opp_prob_sum, prob_sum, 1-player, opp_reach, board); + int n_hand = hand_size[player], size = node->n_act * n_hand, h = 0, i = 0; + int *same_hand = same_hand_ptr[player], *my_card = hand_card_ptr[player]; + size_t *my_hash = hand_hash_ptr[player]; + for(h = 0; h < n_hand; h++) { + temp = same_hand[h] != -1 ? opp_reach[same_hand[h]] : 0;// 重复计算的部分 + temp = prob_sum - opp_prob_sum[my_card[h]] - opp_prob_sum[my_card[h+n_hand]] + temp; + if((my_hash[h] & board) || temp == 0) { + for(i = h; i < size; i += n_hand) cfv[i] = 0; + } + else { + for(i = h; i < size; i += n_hand) cfv[i] /= temp; + } + } +} + +// #define TIME_LOG +#ifdef TIME_LOG +atomic_ullong fold_time[16] = {0}, sd_time[16] = {0}; +#endif + +void SliceCFR::leaf_cfv(int player) { +#ifdef TIME_LOG + Timer timer; + for(int i = 0; i < n_thread; i++) { + fold_time[i].store(0), sd_time[i].store(0); + } +#endif + int opp = 1 - player; + int my_hand = hand_size[player], opp_hand = hand_size[opp]; + vector &vec = pre_leaf_node[player]; + int64_t n = vec.size(); + #pragma omp parallel for schedule(dynamic) + // #pragma omp parallel for + for(int64_t i = 0; i < n; i++) { + // printf("omp_get_thread_num():%d,%zd\n", omp_get_thread_num(), i); + float *cfv = vec[i].cfv; + // for(int j = 0; j < my_hand; j++) cfv[j] = 0; + for(int j : vec[i].leaf_node_idx) { + LeafNode &node = leaf_node[j]; + if(j < sd_offset) { + fold_cfv(player, cfv, node.reach_prob[opp], my_hand, ev_ptr[j], node.info); 
+ } + else sd_cfv(player, cfv, node.reach_prob[opp], my_hand, opp_hand, ev_ptr[j], node.info); + } + } +#ifdef TIME_LOG + for(int i = 0; i < n_thread; i++) { + printf("%zd\t%zd\n", fold_time[i].load(), sd_time[i].load()); + } + // printf("leaf_cfv:%zd ms\n", timer.ms()); +#endif +} +void SliceCFR::get_prob_sum(vector &prob_sum, float &sum, int player, float *reach_prob, size_t board) { + float temp = 0; + int n_hand = hand_size[player], *hand_card = hand_card_ptr[player]; + size_t *hand_hash = hand_hash_ptr[player]; + for(int i = 0; i < n_hand; i++) { + if(hand_hash[i] & board) continue;// 对方手牌与公共牌冲突 + temp = reach_prob[i]; + prob_sum[hand_card[i]] += temp;// card1 + prob_sum[hand_card[i+n_hand]] += temp;// card2 + sum += temp; + } +} +void SliceCFR::fold_cfv(int player, float *cfv, float *opp_reach, int my_hand, float val, size_t board) { +#ifdef TIME_LOG + Timer timer; +#endif + if(player != P0) val = -val; + size_t *my_hash = hand_hash_ptr[player]; + int *my_card = hand_card_ptr[player]; + int *same_hand = same_hand_ptr[player], i = 0; + vector opp_prob_sum(n_card, 0); + float prob_sum = 0, temp = 0; + get_prob_sum(opp_prob_sum, prob_sum, 1-player, opp_reach, board); + for(i = 0; i < my_hand; i++) { + if(my_hash[i] & board) { + // cfv[i] = 0;// 与公共牌冲突,cfv为0 + continue; + } + temp = same_hand[i] != -1 ? 
opp_reach[same_hand[i]] : 0;// 重复计算的部分 + cfv[i] += (prob_sum - opp_prob_sum[my_card[i]] - opp_prob_sum[my_card[i+my_hand]] + temp) * val; + } +#ifdef TIME_LOG + fold_time[omp_get_thread_num()] += timer.us(); +#endif +} +void SliceCFR::sd_cfv(int player, float *cfv, float *opp_reach, int my_hand, int opp_hand, float val, int idx) { +#ifdef TIME_LOG + Timer timer; +#endif + vector &vec = strength[idx]; + const RiverCombs *my_data = vec[player].data, *opp_data = vec[1-player].data; + int my_size = vec[player].size, opp_size = vec[1-player].size, i = 0, j = 0, h = 0, rank = 0; + int *my_card = hand_card_ptr[player], *opp_card = hand_card_ptr[1-player]; + vector opp_prob_sum(n_card, 0); + float prob_sum = 0; + for(i = 0, j = 0; i < my_size; i++) {// strength值变小,己方手牌变强 + rank = my_data[i].rank; + for(; j < opp_size && opp_data[j].rank > rank; j++) {// (胜过对方条件下)找到对方的最强手牌 + h = opp_data[j].reach_prob_index; + opp_prob_sum[opp_card[h]] += opp_reach[h];// card1 + opp_prob_sum[opp_card[h+opp_hand]] += opp_reach[h];// card2 + prob_sum += opp_reach[h]; + } + h = my_data[i].reach_prob_index; + cfv[h] += (prob_sum - opp_prob_sum[my_card[h]] - opp_prob_sum[my_card[h+my_hand]]) * val; + } + prob_sum = 0; + for(h = 0; h < n_card; h++) opp_prob_sum[h] = 0; + for(i = my_size-1, j = opp_size-1; i >= 0; i--) {// strength值变大,己方手牌变弱 + rank = my_data[i].rank; + for(; j >= 0 && opp_data[j].rank < rank; j--) {// (败给对方条件下)找到对方的最弱手牌 + h = opp_data[j].reach_prob_index; + opp_prob_sum[opp_card[h]] += opp_reach[h];// card1 + opp_prob_sum[opp_card[h+opp_hand]] += opp_reach[h];// card2 + prob_sum += opp_reach[h]; + } + h = my_data[i].reach_prob_index; + cfv[h] += (opp_prob_sum[my_card[h]] + opp_prob_sum[my_card[h+my_hand]] - prob_sum) * val; + } +#ifdef TIME_LOG + sd_time[omp_get_thread_num()] += timer.us(); +#endif +} +void SliceCFR::append_node_idx(int p_idx, int act_idx, int player, int leaf_node_idx) { + if(p_idx == -1) { + root_child_idx[player].push_back(leaf_node_idx); + 
leaf_node[leaf_node_idx].reach_prob[player] = root_prob_ptr[player]; + return; + } + vector &vec = pre_leaf_node[player]; + int n_hand = hand_size[player], offset = reach_prob_to_cfv(dfs_node[p_idx].n_act, n_hand); + float *cfv = player_node[dfs_idx_map[p_idx]].data + cfv_offset(n_hand, act_idx); + if(pre_leaf_node_map[p_idx].empty()) pre_leaf_node_map[p_idx] = vector(dfs_node[p_idx].n_act, -1); + int &i = pre_leaf_node_map[p_idx][act_idx]; + if(i == -1) {// 未初始化 + i = vec.size(); + vec.emplace_back(cfv); + } + vec[i].leaf_node_idx.push_back(leaf_node_idx); + leaf_node[leaf_node_idx].reach_prob[player] = cfv + offset; +} +size_t SliceCFR::init_leaf_node() { + pre_leaf_node_map = vector>(dfs_idx); + pre_leaf_node = vector>(N_PLAYER); + root_child_idx = vector>(N_PLAYER); + leaf_node = vector(n_leaf_node); + int node_idx = 0; + for(int i = 0; i < N_LEAF_TYPE; i++) { + for(int idx : leaf_node_dfs[i]) { + DFSNode &node = dfs_node[idx]; + append_node_idx(node.parent_p0_idx, node.parent_p0_act, P0, node_idx); + append_node_idx(node.parent_p1_idx, node.parent_p1_act, P1, node_idx); + int j = decode_idx0(node.info), k = decode_idx1(node.info); + size_t info = init_board; + if(i == FOLD_TYPE) { + if(j != -1) info |= 1LL << poss_card[j]; + if(k != -1) info |= 1LL << poss_card[k]; + } + else { + if(j == -1) info = 0; + else if(k == -1) info = j; + else info = tril_idx(max(j, k), min(j, k)); + } + leaf_node[node_idx++].info = info; + } + } + sd_offset = leaf_node_dfs[FOLD_TYPE].size(); + logger->log("%zd,%zd", pre_leaf_node[P0].size(), pre_leaf_node[P1].size()); + logger->log("%d,%d,%zd,%zd", n_leaf_node, node_idx, root_child_idx[P0].size(), root_child_idx[P1].size()); + + size_t max_val[N_PLAYER] = {0, 0}, min_val[N_PLAYER] = {INT_MAX, INT_MAX}; + for(int i = 0; i < N_PLAYER; i++) { + if(!root_child_idx[i].empty()) { + pre_leaf_node[i].emplace_back(root_cfv_ptr[i]); + pre_leaf_node[i].back().leaf_node_idx = root_child_idx[i]; + } + for(PreLeafNode &node : pre_leaf_node[i]) { 
+ assert(node.cfv != nullptr); + max_val[i] = max(max_val[i], node.leaf_node_idx.size()); + min_val[i] = min(min_val[i], node.leaf_node_idx.size()); + } + } + logger->log("%zd,%zd,%zd,%zd", min_val[P0], max_val[P0], min_val[P1], max_val[P1]); + + ev[FOLD_TYPE].insert(ev[FOLD_TYPE].end(), ev[SHOWDOWN_TYPE].begin(), ev[SHOWDOWN_TYPE].end()); + ev[SHOWDOWN_TYPE].clear(); + ev_ptr = ev[FOLD_TYPE].data(); + size_t total = n_leaf_node * sizeof(LeafNode); + total += (pre_leaf_node[P0].size() + pre_leaf_node[P1].size()) * sizeof(PreLeafNode); + total += n_leaf_node * N_PLAYER * sizeof(int);// leaf_node_idx + return total; +} + +SliceCFR::SliceCFR( + shared_ptr tree, + vector &range1, + vector &range2, + vector &initial_board, + shared_ptr compairer, + Deck &deck, + int train_step, + int print_interval, + float accuracy, + int n_thread, + Logger *logger +):deck(deck), steps(train_step), interval(print_interval), n_thread(max(0,n_thread)), rrm(compairer), Solver(tree, logger) { + init_board = Card::boardInts2long(initial_board); + init_round = GameTreeNode::gameRound2int(tree->getRoot()->getRound()); + if(init_round < FLOP_ROUND) return; + init_hand_card(range1, range2); + if(hand_size[P0] == 0 || hand_size[P1] == 0) return; + init_same_hand_idx(); + init_min_card(); + init_poss_card(deck, init_board); + normalization(); + tol = accuracy / 100 * tree->getRoot()->getPot(); + if(this->n_thread == 0) this->n_thread = omp_get_num_procs(); + omp_set_num_threads(this->n_thread); + // test_parallel_for(this->n_thread); +} + +void SliceCFR::init() { + float unit = 1 << 20; + size_t size = estimate_tree_size(); + logger->log("estimate memory:%f MB", size/unit); + + leaf_node_dfs.resize(N_LEAF_TYPE); + ev.resize(N_LEAF_TYPE); + slice.resize(N_PLAYER); + dfs_idx = 0; + dfs(tree->getRoot(), -1, -1, -1, -1, -1, -1, 0, 0, 0); + + print_array(node_cnt, N_TYPE); + for(int i = 0; i < N_LEAF_TYPE; i++) { + printf("%zd,", leaf_node_dfs[i].size()); + assert(node_cnt[i] == 
leaf_node_dfs[i].size()); + } + for(int player = P0; player < N_PLAYER; player++) { + size = 0; + for(vector &nodes : slice[player]) size += nodes.size(); + printf("%zd,", size); + assert(size == node_cnt[N_LEAF_TYPE+player]); + } + printf("%zd\n", chance_node.size()); + assert(node_cnt[N_LEAF_TYPE+CHANCE_PLAYER] == chance_node.size()); + + if(dfs_idx == 0 || dfs_node[0].n_act == 0) return; + size = init_memory(); + logger->log("%d nodes, total:%f MB", dfs_idx, size/unit); + init_succ = true; +} + +SliceCFR::~SliceCFR() { + for(Node &node : player_node) { + if(node.data) free(node.data); + } +} + +void SliceCFR::set_cfv_and_offset(DFSNode &node, int player, float *&cfv, int &offset) { + if(player == -1) player = node.player;// 向上连接同玩家节点 + int p_idx = node.parent_p0_idx, act_idx = node.parent_p0_act;// 向上连接P0 + if(player != P0) {// 向上连接P1 + p_idx = node.parent_p1_idx; + act_idx = node.parent_p1_act; + } + if(p_idx == -1) { + cfv = root_cfv_ptr[player]; + offset = root_prob_ptr[player] - root_cfv_ptr[player]; + // mtx = (mutex *)player; + } + else { + if(player != dfs_node[p_idx].player) throw runtime_error("player mismatch"); + cfv = player_node[dfs_idx_map[p_idx]].data + cfv_offset(hand_size[player], act_idx); + offset = reach_prob_to_cfv(dfs_node[p_idx].n_act, hand_size[player]); + // if(mtx_map[p_idx].empty()) mtx_map[p_idx] = vector(dfs_node[p_idx].n_act, -1); + // int &i = mtx_map[p_idx][act_idx]; + // if(i == -1) i = mtx_idx++; + // mtx = (mutex *)i; + } +} + +size_t SliceCFR::init_player_node() { + size_t total = 0, size = 0; + player_node = vector(n_player_node); + player_node_ptr = player_node.data(); + dfs_idx_map = vector(dfs_idx, -1); + slice_offset = vector>(N_PLAYER); + // mtx_map = vector>(dfs_idx); + // mtx_idx = N_PLAYER; + int mem_idx = 0; + for(int i = 0; i < N_PLAYER; i++) {// 枚举player + for(vector &nodes : slice[i]) {// 枚举slice + slice_offset[i].push_back(mem_idx); + for(int idx : nodes) {// 枚举node + dfs_idx_map[idx] = mem_idx++; + } + } + 
slice_offset[i].push_back(mem_idx); + } + for(int idx = 0; idx < dfs_idx; idx++) { + if(dfs_idx_map[idx] == -1) continue; + DFSNode &node = dfs_node[idx]; + if(node.player != P0 && node.player != P1) throw runtime_error("unknow player"); + Node &target = player_node[dfs_idx_map[idx]]; + target.n_act = node.n_act; + set_cfv_and_offset(node, -1, target.parent_cfv, target.parent_offset); + float *ptr = nullptr; + int offset = 0; + set_cfv_and_offset(node, 1-node.player, ptr, offset); + target.opp_prob = ptr + offset; + target.board = init_board; + int j = decode_idx0(node.info), k = decode_idx1(node.info); + if(j != -1) target.board |= 1LL << poss_card[j]; + if(k != -1) target.board |= 1LL << poss_card[k]; + size = get_size(node.n_act, hand_size[node.player]) * sizeof(float); + target.data = (float *)malloc(size); + if(target.data == nullptr) throw runtime_error("malloc error"); + total += size; + } + // mtx = vector(mtx_idx); + // printf("%d,%d,%d\n", sizeof(mutex), mtx_idx, mtx_idx * sizeof(mutex)); + // total += mtx_idx * sizeof(mutex); + // for(int i : dfs_idx_map) { + // if(i == -1) continue; + // player_node[i].mtx = &mtx[(size_t)(player_node[i].mtx)]; + // } + total += n_player_node * sizeof(Node); + return total; +} + +size_t SliceCFR::init_memory() { + size_t total = 0; + int n = root_prob.size(); + root_cfv = vector(n<<1, 0); + for(int i = 0; i < n; i++) root_cfv[n+i] = root_prob[i]; + total += n * 3 * sizeof(float); + root_cfv_ptr[P0] = root_cfv.data(); + root_cfv_ptr[P1] = root_cfv_ptr[P0] + hand_size[P0]; + root_prob_ptr[P0] = root_cfv_ptr[P0] + n; + root_prob_ptr[P1] = root_prob_ptr[P0] + hand_size[P0]; + + total += init_player_node(); + total += init_leaf_node(); + total += init_strength_table(); + return total; +} + +size_t SliceCFR::init_strength_table() { + int n = poss_card.size(); + vector board_hash; + if(init_round == RIVER_ROUND) board_hash.push_back(init_board); + else if(init_round == TURN_ROUND) { + board_hash = vector(n, 0); + for(int i = 0; 
i < n; i++) board_hash[i] = init_board | (1LL<(n*(n-1)>>1, 0); + for(int i = 0; i < n; i++) { + for(int j = i+1; j < n; j++) { + board_hash[tril_idx(j, i)] = init_board | two_card_hash(poss_card[i], poss_card[j]); + } + } + } + n = board_hash.size(); + strength = vector>(n); + // omp_set_num_threads(omp_get_num_procs()); + #pragma omp parallel for + for(int i = 0; i < n; i++) { + // printf("omp_get_thread_num():%d,%d\n", omp_get_thread_num(), i); + const vector& p0_comb = rrm.getRiverCombos(P0, ranges[P0], board_hash[i]); + const vector& p1_comb = rrm.getRiverCombos(P1, ranges[P1], board_hash[i]); + strength[i].emplace_back(p0_comb.size(), p0_comb.data()); + strength[i].emplace_back(p1_comb.size(), p1_comb.data()); + } + size_t total = (n<<1) * sizeof(StrengthData), size = 0; + for(int i = 0; i < n; i++) size += strength[i][P0].size + strength[i][P1].size; + total += (size<<1) * sizeof(int);// rank,idx + return total; +} + +void SliceCFR::init_min_card() { + min_card = N_CARD; + int max_card = -1; + for(int card : hand_card) { + min_card = min(min_card, card); + max_card = max(max_card, card); + } + n_card = max_card - min_card + 1;// 52张牌中如果只用了连续的一段,可以节省内存 + for(int &card : hand_card) card -= min_card; +} + +void SliceCFR::init_hand_card(vector &range1, vector &range2) { + ranges = vector>(2); + vector cards;// card1,card2,card1,card2,... 
+ init_hand_card(range1, cards, root_prob, init_board, ranges[P0]); + hand_size[P0] = root_prob.size(); + init_hand_card(range2, cards, root_prob, init_board, ranges[P1]); + hand_size[P1] = root_prob.size() - hand_size[P0]; + hand_card = vector(cards.size()); + hand_hash = vector(root_prob.size()); + int stop[N_PLAYER] = {hand_size[P0]<<1, cards.size()}; + int i = 0, j = 0, k = 0, n = 0; + for(int p = 0; p < N_PLAYER; p++) { + for(n = hand_size[p], i = j; j < stop[p]; j += 2, i++) { + hand_card[i] = cards[j]; + hand_card[i+n] = cards[j+1]; + hand_hash[k++] = two_card_hash(cards[j], cards[j+1]); + } + } + hand_card_ptr[P0] = hand_card.data(); + hand_card_ptr[P1] = hand_card_ptr[P0] + stop[P0]; + hand_hash_ptr[P0] = hand_hash.data(); + hand_hash_ptr[P1] = hand_hash_ptr[P0] + hand_size[P0]; +} + +void SliceCFR::init_hand_card(vector &range, vector &cards, vector &prob, size_t board, vector &out) { + unordered_set seen; + for(PrivateCards &hand : range) { + size_t hash = hand.toBoardLong(); + if(seen.count(hash)) continue;// 去重 + if(hash & board) continue;// 和公共牌冲突 + seen.insert(hash); + cards.push_back(min(hand.card1, hand.card2)); + cards.push_back(max(hand.card1, hand.card2)); + prob.push_back(hand.weight); + out.push_back(hand); + } +} + +void SliceCFR::init_same_hand_idx() { + int n = root_prob.size(), p0_size = hand_size[P0]; + same_hand_idx = vector(n, -1); + unordered_map hash2idx; + for(int h = 0; h < p0_size; h++) hash2idx[hand_hash[h]] = h; + for(int h = p0_size; h < n; h++) {// P1 + size_t hash = hand_hash[h]; + if(hash2idx.count(hash)) { + same_hand_idx[h] = hash2idx[hash]; + same_hand_idx[hash2idx[hash]] = h - p0_size; + } + } + same_hand_ptr[P0] = same_hand_idx.data(); + same_hand_ptr[P1] = same_hand_ptr[P0] + p0_size; +} + +void SliceCFR::normalization() { + int p0_size = hand_size[P0], n = root_prob.size(); + norm = 0; + // 每个history的概率为p0_prob*p1_prob*chance_prob*mask/norm + // p0手牌,p1手牌,公共牌之间有冲突时mask=0,无冲突时mask=1 + // cfr迭代过程中,不需要考虑norm + for(int p0 
= 0; p0 < p0_size; p0++) { + for(int p1 = p0_size; p1 < n; p1++) { + if(!cards_valid(hand_hash[p0], hand_hash[p1])) continue; + norm += root_prob[p0] * root_prob[p1]; + } + } +} + +size_t SliceCFR::estimate_tree_size() { + for(int i = 0; i < N_TYPE; i++) node_cnt[i] = 0; + if(tree == nullptr) return 0; + size_t size = _estimate_tree_size(tree->getRoot()); + n_leaf_node = node_cnt[FOLD_TYPE] + node_cnt[SHOWDOWN_TYPE]; + n_player_node = node_cnt[N_LEAF_TYPE+P0] + node_cnt[N_LEAF_TYPE+P1]; + size *= sizeof(float); + size += n_leaf_node * sizeof(LeafNode); + size += n_player_node * sizeof(Node); + return size; +} + +size_t SliceCFR::_estimate_tree_size(shared_ptr node) { + int type = node->getType(), round = GameTreeNode::gameRound2int(node->getRound()), n_act = 0; + size_t size = 0; + if(type == GameTreeNode::ACTION) { + shared_ptr act_node = dynamic_pointer_cast(node); + vector> children = act_node->getChildrens(); + n_act = children.size(); + int player = act_node->getPlayer(); + node_cnt[N_LEAF_TYPE + player]++; + size += get_size(n_act, hand_size[player]); + for(int i = 0; i < n_act; i++) size += _estimate_tree_size(children[i]); + } + else if(type == GameTreeNode::CHANCE) { + shared_ptr chance_node = dynamic_pointer_cast(node); + shared_ptr children = chance_node->getChildren();// 不为null + int child_type = children->getType(); + n_act = chance_branch[round] + 4; + node_cnt[N_LEAF_TYPE + CHANCE_PLAYER]++; + if(child_type == GameTreeNode::ACTION || child_type == GameTreeNode::SHOWDOWN) { + for(int i = 0; i < n_act; i++) size += _estimate_tree_size(children); + } + else {// CHANCE之后接着CHANCE,再接着SHOWDOWN + node_cnt[SHOWDOWN_TYPE] += (n_act*(n_act-1)>>1); + } + } + else if(type == GameTreeNode::SHOWDOWN) node_cnt[SHOWDOWN_TYPE]++; + else node_cnt[FOLD_TYPE]++; + return size; +} + +void SliceCFR::dfs(shared_ptr node, int parent_act, int parent_dfs_idx, int parent_p0_act, int parent_p0_idx, int parent_p1_act, int parent_p1_idx, int cnt0, int cnt1, int info) { + int 
curr_idx = dfs_idx++; + int type = node->getType(), round = GameTreeNode::gameRound2int(node->getRound()), n_act = 0; + if(type == GameTreeNode::ACTION) { + shared_ptr act_node = dynamic_pointer_cast(node); + ActionNode *p = act_node.get(); + int r_offset = round - init_round; + if(node_idx.find(p) == node_idx.end()) node_idx[p] = vector(combination_num[r_offset], -1); + int j = decode_idx0(info), k = decode_idx1(info); + if(r_offset == 0) { + assert(j == -1 && k == -1); + node_idx[p][0] = curr_idx; + } + else if(r_offset == 1) { + assert(j != -1 && k == -1); + node_idx[p][poss_card[j]] = curr_idx; + } + else { + assert(r_offset == 2 && j != -1 && k != -1); + node_idx[p][poss_card[j]*N_CARD+poss_card[k]] = curr_idx; + } + int player = act_node->getPlayer(); + vector> children = act_node->getChildrens(); + n_act = children.size(); + dfs_node.emplace_back(player, n_act, parent_act, info | round, parent_dfs_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); + vector> &player_slice = slice[player]; + if(player == P0) { + if(player_slice.size() == cnt0) player_slice.emplace_back(); + player_slice[cnt0++].push_back(curr_idx); + for(int i = 0; i < n_act; i++) dfs(children[i], i, curr_idx, i, curr_idx, parent_p1_act, parent_p1_idx, cnt0, cnt1, info); + } + else {// P1 + if(player_slice.size() == cnt1) player_slice.emplace_back(); + player_slice[cnt1++].push_back(curr_idx); + for(int i = 0; i < n_act; i++) dfs(children[i], i, curr_idx, parent_p0_act, parent_p0_idx, i, curr_idx, cnt0, cnt1, info); + } + } + else if(type == GameTreeNode::CHANCE) { + shared_ptr chance_node = dynamic_pointer_cast(node); + shared_ptr children = chance_node->getChildren();// 不为null + int child_type = children->getType(); + n_act = chance_branch[round] + 4; + this->chance_node.push_back(curr_idx); + if(child_type == GameTreeNode::ACTION || child_type == GameTreeNode::SHOWDOWN) {// 需要发1张牌 + dfs_node.emplace_back(CHANCE_PLAYER, n_act, parent_act, info | round, parent_dfs_idx, 
parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); + // 发牌信息编码,只有1张牌时,占用idx0,有2张牌时,占用idx0,idx1 + int j = decode_idx0(info), new_info = 0; + for(int i = 0, k = 0; i < n_act; i++, k++) {// 动作索引i,poss_card索引k + if(j == -1) new_info = code_idx0(k);// 第一次发牌 + else {// 第二次发牌,最多发两次牌 + if(k == j) k++;// 两次选的一样,则第二次改成下一个 + new_info = code_idx0(j) | code_idx1(k); + } + dfs(children, i, curr_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx, cnt0, cnt1, new_info); + } + } + else {// CHANCE之后接着CHANCE,再接着SHOWDOWN,需要连续发2张牌 + assert(round == TURN_ROUND); + shared_ptr child = dynamic_pointer_cast(children); + assert(child->getChildren()->getType() == GameTreeNode::SHOWDOWN); + int parent_player = dfs_node[parent_dfs_idx].player;// 父节点玩家 + dfs_node.emplace_back(CHANCE_PLAYER, n_act*(n_act-1)>>1, parent_act, info | round, parent_dfs_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); + // float val = node->getPot()/2*2/chance_den[RIVER_ROUND]; + float val = node->getPot()/chance_den[RIVER_ROUND]; + for(int i = 0, j = 0; j < n_act; j++) { + for(int k = j+1; k < n_act; k++) { + ev[SHOWDOWN_TYPE].push_back(val); + leaf_node_dfs[SHOWDOWN_TYPE].push_back(dfs_idx++); + info = code_idx0(j) | code_idx1(k); + dfs_node.emplace_back(CHANCE_PLAYER, 0, i++, info, curr_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); + } + } + } + } + else {// river SHOWDOWN, fold + assert(parent_dfs_idx != -1); + int parent_player = dfs_node[parent_dfs_idx].player;// 父节点玩家 + int i = SHOWDOWN_TYPE; + float val = 0; + if(type == GameTreeNode::SHOWDOWN) val = node->getPot()/2; + else {// fold + vector pot = dynamic_pointer_cast(node)->get_payoffs(); + val = pot[P0]; + i = FOLD_TYPE; + } + leaf_node_dfs[i].push_back(curr_idx); + ev[i].push_back(val / chance_den[round]); + dfs_node.emplace_back(parent_player, 0, parent_act, info, parent_dfs_idx, parent_p0_act, parent_p0_idx, parent_p1_act, parent_p1_idx); + } +} + +void SliceCFR::init_poss_card(Deck& deck, size_t 
board) { + vector &cards = deck.getCards(); + for(Card& card : cards) { + int i = card.getCardInt(); + if(cards_valid(1LL<= 0; r--) chance_branch[r] = poss_card.size() - 4;// 排除2个玩家的手牌,总共4张 + for(int r = init_round+2; r < N_ROUND; r++) chance_branch[r] = chance_branch[r-1] - 1; + print_array(chance_branch, N_ROUND); + for(int r = 0; r <= init_round; r++) chance_den[r] = 1; + for(int r = init_round+1; r < N_ROUND; r++) chance_den[r] = chance_den[r-1] * chance_branch[r]; + print_array(chance_den, N_ROUND); +} + +void SliceCFR::_reach_prob(int player, bool avg_strategy) { + vector& offset = slice_offset[player]; + int n = offset.size(), n_hand = hand_size[player]; + node_func func = avg_strategy ? reach_prob_avg : reach_prob; + for(int i = 1; i < n; i++) { + #pragma omp parallel for + for(int j = offset[i-1]; j < offset[i]; j++) { + func(player_node_ptr+j, n_hand); + } + } +} +void SliceCFR::_rm(int player, bool avg_strategy) { + node_func func = avg_strategy ? rm_avg : rm; + int s = slice_offset[player][0], e = slice_offset[player].back(), n_hand = hand_size[player]; + #pragma omp parallel for + for(int i = s; i < e; i++) { + func(player_node_ptr+i, n_hand); + } +} + +void SliceCFR::clear_data(int player) { + int s = slice_offset[player][0], e = slice_offset[player].back(), n_hand = hand_size[player]; + size_t size = 0; + for(int i = s; i < e; i++) { + size = get_size(player_node_ptr[i].n_act, n_hand) * sizeof(float); + memset(player_node_ptr[i].data, 0, size); + } +} + +void SliceCFR::clear_root_cfv() { + size_t size = root_prob.size() * sizeof(float); + memset(root_cfv_ptr[P0], 0, size); +} + +bool SliceCFR::print_exploitability(int iter, Timer &timer) { + vector res = exploitability(); + logger->log("%d:%.3fs", iter, timer.ms()/1000.0); + float avg = (res[0] + res[1]) / 2; + logger->log("%d:%f %f %f", iter, res[0], res[1], avg); + return avg <= tol; +} + +void SliceCFR::train() { + init(); + if(!init_succ) return; + Timer timer; + clear_data(P0); + clear_data(P1); 
+ // _rm(P0, false); + // _rm(P1, false); + // _reach_prob(P0, false); + print_exploitability(0, timer); + // 计算exploitability后,双方的rm和p0的reach_prob已经恢复 + // pos_coef = neg_coef = coef = 0; + double temp = 0; + int cnt = 0, iter = 0; + while(iter < steps) { + temp = pow(iter, alpha); + pos_coef = temp / (temp + 1); + temp = pow(iter, beta); + neg_coef = temp / (temp + 1); + // neg_coef = 0.5; + coef = pow((float)iter/(iter+1), gamma); + + clear_root_cfv(); + for(int player = P0; player < N_PLAYER; player++) { + step(iter, player, CFR_TASK); + } + iter++; + if((++cnt) == interval) { + cnt = 0; + if(print_exploitability(iter, timer)) break; + } + if(stop_flag) break; + } + if(cnt) { + print_exploitability(iter, timer); + } + logger->log("collecting statics"); + for(int player = P0; player < N_PLAYER; player++) { + _rm(1-player, true); + step(iter, player, CFV_TASK); + } + cfv_to_ev(); + logger->log("statics collected"); +} + +// 执行更新任务时,player到达概率需要提前计算好 +void SliceCFR::step(int iter, int player, int task) { +#ifdef TIME_LOG + size_t start = timeSinceEpochMillisec(), end = 0; +#endif + int opp = 1 - player, my_hand = hand_size[player]; + _reach_prob(opp, task != CFR_TASK); +#ifdef TIME_LOG + end = timeSinceEpochMillisec(); + size_t t1 = end - start; + start = end; +#endif + + leaf_cfv(player); +#ifdef TIME_LOG + end = timeSinceEpochMillisec(); + size_t t2 = end - start; + start = end; +#endif + + vector& offset = slice_offset[player]; + if(task == CFR_TASK) { + #pragma omp parallel for + for(int j = offset[0]; j < offset.back(); j++) { + discount_data(player_node_ptr+j, my_hand, pos_coef, neg_coef, coef); + } + } +#ifdef TIME_LOG + end = timeSinceEpochMillisec(); + size_t t3 = end - start; + start = end; +#endif + + node_func func; + switch(task) { + case EXP_TASK:{func = best_cfv_up;break;} + case CFV_TASK:{func = cfv_up_avg;break;} + default:func = cfv_up; + } + for(int i = offset.size()-1; i > 0; i--) { + #pragma omp parallel for + for(int j = offset[i-1]; j < 
offset[i]; j++) { + func(player_node_ptr+j, my_hand); + } + } +#ifdef TIME_LOG + end = timeSinceEpochMillisec(); + size_t t4 = end - start; + printf("%zd\t%zd\t%zd\t%zd\n", t1, t2, t3, t4); +#endif +} + +vector SliceCFR::exploitability() { + int opp = 0; + clear_root_cfv(); + for(int player = P0; player < N_PLAYER; player++) { +#ifdef TIME_LOG + size_t start = timeSinceEpochMillisec(); +#endif + opp = 1 - player; + _rm(opp, true);// 改变对方策略 +#ifdef TIME_LOG + size_t t1 = timeSinceEpochMillisec() - start; +#endif + step(0, player, EXP_TASK); +#ifdef TIME_LOG + start = timeSinceEpochMillisec(); +#endif + _rm(opp, false);// 恢复对方策略 +#ifdef TIME_LOG + size_t t2 = timeSinceEpochMillisec() - start; + printf("rm time:%zd\t%zd\n", t1, t2); +#endif + } + post_process(); + _reach_prob(P0, false);// 恢复P0的reach_prob,用于下一次迭代 + int m = 0, n = hand_size[P0]; + float ev0 = 0, ev1 = 0; + for(int i = m; i < n; i++) ev0 += root_cfv[i] * root_prob[i]; + m = n; n = root_prob.size(); + for(int i = m; i < n; i++) ev1 += root_cfv[i] * root_prob[i]; + return {ev0/norm, ev1/norm}; +} + +void SliceCFR::stop() { + stop_flag = true; +} +json SliceCFR::dumps(bool with_status, int depth) {// depth:max_round + int idx = 0; + json ans = reConvertJson(tree->getRoot(), 0, depth, idx, 0); + if(idx != dfs_idx) throw runtime_error("dfs idx error"); + return std::move(ans); +} +vector>> SliceCFR::get_strategy(shared_ptr node, vector cards) { + vector>> ans(N_CARD, vector>(N_CARD)); + output_data(node.get(), cards, ans, false); + return std::move(ans); +} +vector>> SliceCFR::get_evs(shared_ptr node, vector cards) { + vector>> ans(N_CARD, vector>(N_CARD)); + output_data(node.get(), cards, ans, true); + return std::move(ans); +} +void SliceCFR::output_data(ActionNode *node, vector &cards, vector>> &out, bool ev) { + int r_offset = GameTreeNode::gameRound2int(node->getRound()) - init_round; + if(cards.size() != r_offset || r_offset > 2) throw runtime_error("chance_cards error"); + int idx = 0; + size_t board 
= init_board; + if(r_offset >= 1) { + idx = cards[0].getCardInt(); + board |= 1LL << idx; + } + if(r_offset == 2) { + idx = idx * N_CARD + cards[1].getCardInt(); + board |= 1LL << cards[1].getCardInt(); + } + vector> data; + if(ev) data = get_ev(node_idx.at(node)[idx]); + else data = get_avg_strategy(node_idx.at(node)[idx]); + int player = node->getPlayer(), n_hand = hand_size[player], *card = hand_card_ptr[player]; + size_t *ptr = hand_hash_ptr[player]; + for(int h = 0; h < n_hand; h++) { + if(!cards_valid(ptr[h], board)) continue; + out[card[h]+min_card][card[h+n_hand]+min_card].swap(data[h]); + } +} +vector> SliceCFR::get_ev(int idx) { + Node &node = player_node[dfs_idx_map[idx]]; + int n_hand = hand_size[dfs_node[idx].player], n_act = node.n_act; + int i = 0, h = 0, j = 0; + float *cfv = node.data; + vector> ev(n_hand, vector(n_act));// [n_hand,n_act] + for(j = 0; j < n_act; j++) { + for(h = 0; h < n_hand; h++) ev[h][j] = cfv[i++]; + } + return std::move(ev); +} +vector> SliceCFR::get_avg_strategy(int idx) { + Node &node = player_node[dfs_idx_map[idx]]; + int n_hand = hand_size[dfs_node[idx].player], n_act = node.n_act; + int size = n_act * n_hand, i = 0, h = 0, j = 0; + float sum = 0, *strategy_sum = node.data + (size << 1), uni = 1.0 / n_act; + vector> strategy(n_hand, vector(n_act));// [n_hand,n_act] + for(h = 0; h < n_hand; h++) { + sum = 0; + for(i = h; i < size; i += n_hand) sum += strategy_sum[i]; + if(sum == 0) { + for(j = 0; j < n_act; j++) strategy[h][j] = uni; + } + else { + for(j = 0, i = h; j < n_act; j++, i += n_hand) strategy[h][j] = strategy_sum[i] / sum; + } + } + return std::move(strategy); +} +json SliceCFR::reConvertJson(const shared_ptr& node, int depth, int max_depth, int &idx, int info) { + int curr_idx = idx++; + int type = node->getType(), n_act = 0; + json ans; + if(type == GameTreeNode::ACTION) { + shared_ptr one_node = dynamic_pointer_cast(node); + vector actions_str; + if(depth < max_depth) { + int player = one_node->getPlayer(); + 
for(GameActions one_action : one_node->getActions()) actions_str.push_back(one_action.toString()); + ans["actions"] = actions_str; + ans["player"] = player; + ans["node_type"] = "action_node"; + + vector> strategy = get_avg_strategy(curr_idx); + ans["strategy"] = json(); + ans["strategy"]["actions"] = actions_str; + json stt; + size_t n_hand = hand_size[player]; + int *ptr = hand_card_ptr[player]; + for(size_t i = 0; i < n_hand; i++) { + stt[Card::intCard2Str(ptr[i+n_hand]+min_card)+Card::intCard2Str(ptr[i]+min_card)] = strategy[i]; + } + ans["strategy"]["strategy"] = std::move(stt); + + ans["childrens"] = json(); + } + vector> children = one_node->getChildrens(); + n_act = children.size(); + for(int i = 0; i < n_act; i++) { + json child = reConvertJson(children[i], depth, max_depth, idx, info); + if(depth < max_depth) ans["childrens"][actions_str[i]] = child; + } + } + else if(type == GameTreeNode::CHANCE) { + if((++depth) <= max_depth) ans["node_type"] = "chance_node"; + shared_ptr chance_node = dynamic_pointer_cast(node); + shared_ptr children = chance_node->getChildren();// 不为null + int child_type = children->getType(); + n_act = chance_branch[GameTreeNode::gameRound2int(node->getRound())] + 4; + if(child_type == GameTreeNode::ACTION || child_type == GameTreeNode::SHOWDOWN) {// 需要发1张牌 + if(depth <= max_depth) ans["deal_number"] = n_act; + if(depth < max_depth) ans["dealcards"] = json();// 需要展开子节点 + int j = decode_idx0(info), new_info = 0; + for(int i = 0, k = 0; i < n_act; i++, k++) {// 动作索引i,poss_card索引k + if(j == -1) new_info = code_idx0(k);// 第一次发牌 + else {// 第二次发牌,最多发两次牌 + if(k == j) k++;// 两次选的一样,则第二次改成下一个 + } + json child = reConvertJson(children, depth, max_depth, idx, new_info); + if(depth < max_depth) ans["dealcards"][Card::intCard2Str(poss_card[k])] = child; + } + } + else { + n_act = n_act*(n_act-1)>>1; + idx += n_act; + if(depth <= max_depth) ans["deal_number"] = n_act; + } + } + // else {} + return std::move(ans); +} \ No newline at end of file 
diff --git a/src/tools/CommandLineTool.cpp b/src/tools/CommandLineTool.cpp index 03e8108..f799a8a 100644 --- a/src/tools/CommandLineTool.cpp +++ b/src/tools/CommandLineTool.cpp @@ -2,39 +2,24 @@ // Created by bytedance on 7.6.21. // #include "include/tools/CommandLineTool.h" -#include - -CommandLineTool::CommandLineTool(string mode,string resource_dir) { - string suits = "c,d,h,s"; - string ranks; - this->resource_dir = resource_dir; - string compairer_file,compairer_file_bin; - int lines; - if(mode == "holdem"){ - ranks = "2,3,4,5,6,7,8,9,T,J,Q,K,A"; - compairer_file = this->resource_dir + "/compairer/card5_dic_sorted.txt"; - compairer_file_bin = this->resource_dir + "/compairer/card5_dic_zipped.bin"; - lines = 2598961; - }else if(mode == "shortdeck"){ - ranks = "6,7,8,9,T,J,Q,K,A"; - compairer_file = this->resource_dir + "/compairer/card5_dic_sorted_shortdeck.txt"; - compairer_file_bin = this->resource_dir + "/compairer/card5_dic_zipped_shortdeck.bin"; - lines = 376993; - }else{ - throw runtime_error(tfm::format("mode not recognized : ",mode)); - } - string logfile_name = "../resources/outputs/outputs_log.txt"; - this->ps = PokerSolver(ranks,suits,compairer_file,lines,compairer_file_bin); +// #include +#include +#include +#include + +CommandLineTool::CommandLineTool() { + // string logfile_name = "../resources/outputs/outputs_log.txt"; + // this->ps = PokerSolver(mode, resource_dir); - StreetSetting gbs_flop_ip = StreetSetting(vector{},vector{},vector{},true); - StreetSetting gbs_turn_ip = StreetSetting(vector{},vector{},vector{},true); - StreetSetting gbs_river_ip = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_flop_ip = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_turn_ip = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_river_ip = StreetSetting(vector{},vector{},vector{},true); - StreetSetting gbs_flop_oop = StreetSetting(vector{},vector{},vector{},true); - StreetSetting gbs_turn_oop = 
StreetSetting(vector{},vector{},vector{},true); - StreetSetting gbs_river_oop = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_flop_oop = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_turn_oop = StreetSetting(vector{},vector{},vector{},true); + // StreetSetting gbs_river_oop = StreetSetting(vector{},vector{},vector{},true); - this->gtbs = make_shared(gbs_flop_ip,gbs_turn_ip,gbs_river_ip,gbs_flop_oop,gbs_turn_oop,gbs_river_oop); + // this->gtbs = make_shared(gbs_flop_ip,gbs_turn_ip,gbs_river_ip,gbs_flop_oop,gbs_turn_oop,gbs_river_oop); //ps.build_game_tree(oop_commit,ip_commit,current_round,raise_limit,small_blind,big_blind,stack,*gtbs.get(),allin_threshold); //cout << "build tree finished" << endl; /* @@ -56,98 +41,156 @@ CommandLineTool::CommandLineTool(string mode,string resource_dir) { */ } -void CommandLineTool::startWorking() { +void CommandLineTool::startWorking(PokerSolver *ps) { string input_line; while(cin) { getline(cin, input_line); - this->processCommand(input_line); + this->processCommand(input_line, ps); }; } -void CommandLineTool::execFromFile(string input_file){ +void CommandLineTool::execFromFile(const char *input_file, PokerSolver *ps) { std::ifstream infile(input_file); std::string input_line; while (std::getline(infile, input_line)) { - this->processCommand(input_line); + this->processCommand(input_line, ps); } } -void split(const string& s, char c, - vector& v) { - string::size_type i = 0; - string::size_type j = s.find(c); - - while (j != string::npos) { +void split(const string& s, char delimiter, vector& v) { + size_t i = s.find_first_not_of(delimiter), j = 0; + while (i != string::npos) { + j = s.find_first_of(delimiter, i+1); + if(j == string::npos) j = s.size(); v.push_back(s.substr(i, j-i)); - i = ++j; - j = s.find(c, j); + i = s.find_first_not_of(delimiter, j+1); + } +} + +template +string tostring(T val) { + string s = to_string(val); + for(size_t i = s.size() - 1; i > 0; i--) { + 
if(s[i] == '0') s.pop_back(); + else if(s[i] == '.') { + s.pop_back(); + break; + } + else break; + } + return s; +} + +template +string tostring_oss(T val) { + ostringstream oss; + oss << val; + return oss.str(); +} - if (j == string::npos) - v.push_back(s.substr(i, s.length())); +void join(const vector &vec, char delimiter, string &out) { + size_t n = vec.size(); + if(n) out += tostring(vec[0]); + for(int i = 1; i < n; i++) { + out += delimiter; + out += tostring(vec[i]); } } +bool CommandLineTool::set_board(string &str) { + board = str; + vector board_str_arr = string_split(board,','); + if(board_str_arr.size() == 3){ + this->current_round = 1; + }else if(board_str_arr.size() == 4){ + this->current_round = 2; + }else if(board_str_arr.size() == 5){ + this->current_round = 3; + }else{ + // throw runtime_error(tfm::format("board %s not recognized",this->board)); + return false; + } + return true; +} -void CommandLineTool::processCommand(string input) { +bool CommandLineTool::set_bet_sizes(string &str, char delimiter, vector *sizes) { + vector params; + split(str, delimiter, params); + int start = (sizes != nullptr ? 
0 : 3); + if(params.size() < start) { + // throw runtime_error("param number error"); + return false; + } + if(sizes == nullptr) { + // oop,turn,bet,30,70,100 + StreetSetting& streetSetting = gtbs.get_setting(params[0], params[1]); + string &bet_type = params[2]; + if(bet_type == "allin") { + if(params.size() == start) streetSetting.allin = true; + else streetSetting.allin = stoi(params[start]); + } + else if(bet_type == "bet") sizes = &(streetSetting.bet_sizes); + else if(bet_type == "raise") sizes = &(streetSetting.raise_sizes); + else if(bet_type == "donk") sizes = &(streetSetting.donk_sizes); + else return false; + } + if(sizes != nullptr) { + sizes->clear(); + std::unordered_set seen; + for(std::size_t i = start; i < params.size(); i++) { + float val = stof(params[i]); + if(seen.count(val)) continue; + sizes->push_back(val); + seen.insert(val); + } + std::sort(sizes->begin(), sizes->end()); + } + return true; +} + +// void show_bet_sizes(std::ofstream &out, const char *player, const char *round, const char *type, vector &sizes) { +// string s; +// join(sizes, ',', s); +// out << "set_bet_sizes " << player << ',' << round << ',' << type; +// if(s.size()) out << ',' << s; +// out << endl; +// } +// void show_bet_sizes(std::ofstream &out, const char *player, const char *round, const char *type, bool allin) { +// out << "set_bet_sizes " << player << ',' << round << ',' << type << ',' << allin; +// } + +void CommandLineTool::processCommand(string &input, PokerSolver *ps) { vector contents; + if(input.empty() || input[0] == '#') return; split(input,' ',contents); if(contents.size() == 0) contents = {input}; if(contents.size() > 2 || contents.size() < 1)throw runtime_error(tfm::format("command not valid: %s",input)); string command = contents[0]; string paramstr = contents.size() == 1 ? 
"" : contents[1]; if(command == "set_pot"){ - this->ip_commit = stof(paramstr) / 2; - this->oop_commit = stof(paramstr) / 2; + set_pot(stof(paramstr)); }else if(command == "set_effective_stack"){ - this->stack = stof(paramstr) + this->ip_commit; + set_effective_stack(stof(paramstr)); }else if(command == "set_board"){ - this->board = paramstr; - vector board_str_arr = string_split(board,','); - if(board_str_arr.size() == 3){ - this->current_round = 1; - }else if(board_str_arr.size() == 4){ - this->current_round = 2; - }else if(board_str_arr.size() == 5){ - this->current_round = 3; - }else{ - throw runtime_error(tfm::format("board %s not recognized",this->board)); - } + set_board(paramstr); }else if(command == "set_range_ip"){ this->range_ip = paramstr; }else if(command == "set_range_oop"){ this->range_oop = paramstr; }else if(command == "set_bet_sizes"){ - vector params; - split(paramstr,',',params); - if(params.size() < 3)throw runtime_error("param number error"); - // oop,turn,bet,30,70,100 - string player = params[0]; - string round = params[1]; - string bet_type = params[2]; - StreetSetting& streetSetting = this->gtbs->get_setting(player,round); - vector* sizes; - if(bet_type == "allin") streetSetting.allin = true; - else if(bet_type == "bet") sizes = &(streetSetting.bet_sizes); - else if(bet_type == "raise") sizes = &(streetSetting.raise_sizes); - else if(bet_type == "donk") sizes = &(streetSetting.donk_sizes); - else throw runtime_error(""); - - if(bet_type == "bet" || bet_type == "raise" || bet_type == "donk"){ - sizes->clear(); - for(std::size_t i = 3;i < params.size();i ++ ){ - sizes->push_back(stof(params[i])); - } - } + set_bet_sizes(paramstr); + }else if(command == "set_raise_limit"){ + this->raise_limit = stoi(paramstr); }else if(command == "set_accuracy"){ this->accuracy = stof(paramstr); }else if(command == "set_allin_threshold"){ this->allin_threshold = stof(paramstr); }else if(command == "set_thread_num"){ - this->thread_number = stoi(paramstr); + 
this->thread_num = stoi(paramstr); }else if(command == "build_tree"){ - this->ps.build_game_tree(oop_commit,ip_commit,current_round,raise_limit,small_blind,big_blind,stack,*gtbs.get(),allin_threshold); + build_tree(ps); }else if(command == "set_max_iteration"){ this->max_iteration = stoi(paramstr); }else if(command == "set_use_isomorphism"){ @@ -155,27 +198,103 @@ void CommandLineTool::processCommand(string input) { }else if(command == "set_print_interval"){ this->print_interval = stoi(paramstr); }else if(command == "start_solve"){ - cout << "<<>>" << endl; - this->ps.train( - this->range_ip, - this->range_oop, - this->board, - "tmp_log.txt", - max_iteration, - this->print_interval, - "discounted_cfr", - -1, - this->accuracy, - this->use_isomorphism, - 0, // TODO: enable half float option for command line tool - this->thread_number - ); + start_solve(ps); + }else if(command == "dump_setting"){ + dump_setting(paramstr.c_str()); }else if(command == "dump_result"){ - string output_file = paramstr; - this->ps.dump_strategy(QString::fromStdString(output_file),this->dump_rounds); + res_file = paramstr; + if(!ps) return; + ps->dump_strategy(res_file, this->dump_rounds); }else if(command == "set_dump_rounds"){ this->dump_rounds = stoi(paramstr); + }else if(command == "estimate_tree_memory"){ + if(!ps) return; + if(range_ip.empty() || range_oop.empty() || board.empty()) { + // cout << "Please set range_ip, range_oop and board first." << endl; + logger->log("Please set range_ip, range_oop and board first."); + return; + } + shared_ptr game_tree = ps->get_game_tree(); + if(game_tree == nullptr) { + // cout << "Please buld tree first." 
<< endl; + logger->log("Please buld tree first."); + return; + } + long long size = ps->estimate_tree_memory(range_ip, range_oop, board); + size *= sizeof(float); + // cout << (float)size / (1024*1024) << " MB" << endl; + logger->log("estimate_tree_memory: %f MB", (float)size / (1024*1024)); + }else if(command == "set_slice_cfr"){ + slice_cfr = stoi(paramstr); }else{ - cout << "command not recognized: " << command << endl; + // cout << "command not recognized: " << command << endl; + logger->log("command not recognized: %s", command.c_str()); } } + +void CommandLineTool::dump_setting(const char *file) { + static vector player {"oop","ip"}; + static vector round {"flop","turn","river"}; + static vector type {"bet","raise","donk","allin"}; + std::ofstream out(file); + out << "set_pot " << get_pot() << endl; + out << "set_effective_stack " << get_effective_stack() << endl; + out << "set_board " << board << endl; + out << "set_range_oop " << range_oop << endl; + out << "set_range_ip " << range_ip << endl; + + for(size_t i = 0; i < player.size(); i++) { + for(size_t j = 0; j < round.size(); j++) { + for(size_t k = 0; k < type.size(); k++) { + if(k == 2 && (i == 1 || j == 0)) continue;// no donk:ip, oop flop + out << "set_bet_sizes " << player[i] << ',' << round[j] << ',' << type[k] << ','; + StreetSetting& st = gtbs.get_setting(player[i], round[j]); + if(k == 3) out << st.allin; + else { + vector &vec = (k == 0 ? st.bet_sizes : (k == 1 ? 
st.raise_sizes : st.donk_sizes)); + string str; + join(vec, ',', str); + out << str; + } + out << endl; + } + } + } + out << "set_allin_threshold " << allin_threshold << endl; + out << "set_raise_limit " << raise_limit << endl; + out << "build_tree" << endl; + out << "set_thread_num " << thread_num << endl; + out << "set_accuracy " << accuracy << endl; + out << "set_max_iteration " << max_iteration << endl; + out << "set_print_interval " << print_interval << endl; + out << "set_use_isomorphism " << use_isomorphism << endl; + out << "set_slice_cfr " << slice_cfr << endl; + out << "start_solve" << endl; + out << "set_dump_rounds " << dump_rounds << endl; + out << "dump_result " << res_file << endl; + out.close(); +} + +int cmd_api(string &input_file, string &resource_dir, string &mode, string &log_file) { + if(resource_dir.empty()){ + resource_dir = "./resources"; + } + if(log_file.empty()) log_file = get_localtime() + ".txt"; + Logger logger(true, log_file.c_str(), "w+", true, true, 1); + PokerMode poker_mode = PokerMode::UNKNOWN; + if(mode.empty() || mode == "holdem") poker_mode = PokerMode::HOLDEM; + else if(mode == "shortdeck") poker_mode = PokerMode::SHORTDECK; + else throw runtime_error(tfm::format("mode %s error, not in ['holdem','shortdeck']", mode)); + PokerSolver ps = PokerSolver(poker_mode, resource_dir); + CommandLineTool clt; + clt.logger = &logger; + ps.logger = &logger; + if(input_file.empty()) { + clt.startWorking(&ps); + }else{ + // cout << "EXEC FROM FILE" << endl; + logger.log("EXEC FROM FILE"); + clt.execFromFile(input_file.c_str(), &ps); + } + return 0; +} \ No newline at end of file diff --git a/src/tools/GameTreeBuildingSettings.cpp b/src/tools/GameTreeBuildingSettings.cpp index a815a4c..422c3ed 100644 --- a/src/tools/GameTreeBuildingSettings.cpp +++ b/src/tools/GameTreeBuildingSettings.cpp @@ -13,7 +13,7 @@ GameTreeBuildingSettings::GameTreeBuildingSettings( StreetSetting 
river_oop):flop_ip(flop_ip),turn_ip(turn_ip),river_ip(river_ip),flop_oop(flop_oop),turn_oop(turn_oop),river_oop(river_oop) { } -StreetSetting& GameTreeBuildingSettings::get_setting(string player,string round){ +StreetSetting& GameTreeBuildingSettings::get_setting(string &player, string &round){ if(player == "ip" && round == "flop") return flop_ip; else if(player == "ip" && round == "turn") return turn_ip; else if(player == "ip" && round == "river") return river_ip; diff --git a/src/tools/logger.cpp b/src/tools/logger.cpp new file mode 100644 index 0000000..ebc534c --- /dev/null +++ b/src/tools/logger.cpp @@ -0,0 +1,57 @@ +#include "include/tools/logger.h" +#include + +void get_localtime(char *buf, size_t n, const char *format) { + using namespace std::chrono; + system_clock::time_point tp = system_clock::now(); + time_t now = system_clock::to_time_t(tp); + // time(&now); + int ms = duration_cast(tp.time_since_epoch()).count() - now * 1000; + tm tm_now; +#ifdef _MSC_VER + localtime_s(&tm_now, &now); +#else + localtime_r(&now, &tm_now); +#endif + // strftime(buf, n, format, &tm_now); + snprintf(buf, n, format, tm_now.tm_year+1900, tm_now.tm_mon+1, tm_now.tm_mday, + tm_now.tm_hour, tm_now.tm_min, tm_now.tm_sec, ms); +} + +string get_localtime() { + char buf[25]; + get_localtime(buf, sizeof(buf), "%d_%02d_%02d_%02d_%02d_%02d.%03d"); + return string(buf); +} + +void Logger::log(const char *format, ...) 
{ + if(timestamp) log_time(); + va_list args; + va_start(args, format); + if(file) { + vfprintf(file, format, args); + if((++step) == period) { + step = 0; + fflush(file); + } + if(new_line) fprintf(file, "\n"); +#ifdef __GNUC__ + if(cmd) { + va_end(args); + va_start(args, format); + } +#endif + } + if(cmd) { + vprintf(format, args); + if(new_line) printf("\n"); + } + va_end(args); +} + +void Logger::log_time() { + char buf[28]; + get_localtime(buf, sizeof(buf), "%d-%02d-%02d %02d:%02d:%02d.%03d "); + if(file) fprintf(file, buf); + if(cmd) printf(buf); +} \ No newline at end of file diff --git a/src/ui/boardselectortabledelegate.cpp b/src/ui/boardselectortabledelegate.cpp index a2b0afe..a281058 100644 --- a/src/ui/boardselectortabledelegate.cpp +++ b/src/ui/boardselectortabledelegate.cpp @@ -28,7 +28,8 @@ void BoardSelectorTableDelegate::paint(QPainter *painter, const QStyleOptionView painter->fillRect(rect, brush); QTextDocument doc; - doc.setHtml(Card(options.text.toStdString()).toFormattedHtml()); + Card card(options.text.toStdString()); + doc.setHtml(toFormattedHtml(card)); painter->translate(options.rect.left(), options.rect.top()); QRect clip(0, 0, options.rect.width(), options.rect.height()); diff --git a/src/ui/detailitemdelegate.cpp b/src/ui/detailitemdelegate.cpp index 4d6040b..d79cedd 100644 --- a/src/ui/detailitemdelegate.cpp +++ b/src/ui/detailitemdelegate.cpp @@ -114,8 +114,8 @@ void DetailItemDelegate::paint_strategy(QPainter *painter, const QStyleOptionVie } } options.text = ""; - options.text += detailViewerModel->tableStrategyModel->cardint2card[card1].toFormattedHtml(); - options.text += detailViewerModel->tableStrategyModel->cardint2card[card2].toFormattedHtml(); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card1]); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card2]); options.text = "

" + options.text + "<\/h2>"; for(std::size_t i = 0;i < strategy.size();i ++){ GameActions one_action = gameActions[i]; @@ -189,8 +189,8 @@ void DetailItemDelegate::paint_range(QPainter *painter, const QStyleOptionViewIt painter->fillRect(rect, brush); options.text = ""; - options.text += detailViewerModel->tableStrategyModel->cardint2card[cord.first].toFormattedHtml(); - options.text += detailViewerModel->tableStrategyModel->cardint2card[cord.second].toFormattedHtml(); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[cord.first]); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[cord.second]); options.text = "

" + options.text + "<\/h2>"; options.text += QString("

%1<\/h2>").arg(QString::number(range_number,'f',3)); @@ -318,8 +318,8 @@ void DetailItemDelegate::paint_evs(QPainter *painter, const QStyleOptionViewItem } options.text = ""; - options.text += detailViewerModel->tableStrategyModel->cardint2card[card1].toFormattedHtml(); - options.text += detailViewerModel->tableStrategyModel->cardint2card[card2].toFormattedHtml(); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card1]); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card2]); options.text = "

" + options.text + "<\/h2>"; for(std::size_t i = 0;i < evs.size();i ++){ GameActions one_action = gameActions[i]; @@ -378,7 +378,7 @@ void DetailItemDelegate::paint_evs_only(QPainter *painter, const QStyleOptionVie if(ind < evs.size() and ind < strategy_number) { float one_ev = evs[ind]; - float normalized_ev = normalization_tanh(detailViewerModel->tableStrategyModel->get_solver()->stack,one_ev); + float normalized_ev = normalization_tanh(detailViewerModel->tableStrategyModel->get_solver()->clt->stack,one_ev); //options.text += QString("
%1").arg(QString::number(normalized_ev)); pair strategy_ui_table = detailViewerModel->tableStrategyModel->ui_strategy_table[this->detailWindowSetting->grid_i][this->detailWindowSetting->grid_j][ind]; @@ -396,8 +396,8 @@ void DetailItemDelegate::paint_evs_only(QPainter *painter, const QStyleOptionVie painter->fillRect(rect, brush); options.text = ""; - options.text += detailViewerModel->tableStrategyModel->cardint2card[card1].toFormattedHtml(); - options.text += detailViewerModel->tableStrategyModel->cardint2card[card2].toFormattedHtml(); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card1]); + options.text += toFormattedHtml(detailViewerModel->tableStrategyModel->cardint2card[card2]); options.text = "

" + options.text + "<\/h2>"; options.text += QString("

%1<\/h2>").arg(QString::number(one_ev,'f',3)); diff --git a/src/ui/strategyitemdelegate.cpp b/src/ui/strategyitemdelegate.cpp index b2adb6b..7998324 100644 --- a/src/ui/strategyitemdelegate.cpp +++ b/src/ui/strategyitemdelegate.cpp @@ -198,7 +198,7 @@ void StrategyItemDelegate::paint_evs(QPainter *painter, const QStyleOptionViewIt int last_left = 0; for(std::size_t i = 0;i < evs.size();i ++ ){ float one_ev = evs[evs.size() - i - 1]; - float normalized_ev = normalization_tanh(this->qSolverJob->stack,one_ev); + float normalized_ev = normalization_tanh(this->qSolverJob->clt->stack,one_ev); //options.text += QString("
%1").arg(QString::number(normalized_ev)); int red = max((int)(255 - normalized_ev * 255),0); diff --git a/src/ui/treemodel.cpp b/src/ui/treemodel.cpp index 8c94f6c..0835323 100644 --- a/src/ui/treemodel.cpp +++ b/src/ui/treemodel.cpp @@ -120,9 +120,9 @@ void TreeModel::reGenerateTreeItem(GameTreeNode::GameRound round,TreeItem* node_ void TreeModel::setupModelData() { PokerSolver * solver; - if(this->qSolverJob->mode == QSolverJob::Mode::HOLDEM){ + if(this->qSolverJob->mode == PokerMode::HOLDEM){ solver = &(this->qSolverJob->ps_holdem); - }else if(this->qSolverJob->mode == QSolverJob::Mode::SHORTDECK){ + }else if(this->qSolverJob->mode == PokerMode::SHORTDECK){ solver = &(this->qSolverJob->ps_shortdeck); }else{ throw runtime_error("holdem mode incorrect"); diff --git a/strategyexplorer.cpp b/strategyexplorer.cpp index 0f47162..ebfbbd1 100644 --- a/strategyexplorer.cpp +++ b/strategyexplorer.cpp @@ -45,10 +45,10 @@ StrategyExplorer::StrategyExplorer(QWidget *parent,QSolverJob * qSolverJob) : Deck* deck = this->qSolverJob->get_solver()->get_deck(); int index = 0; - QString board_qstring = QString::fromStdString(this->qSolverJob->board); + QString board_qstring = QString::fromStdString(this->qSolverJob->clt->board); for(Card one_card: deck->getCards()){ - if(board_qstring.contains(QString::fromStdString(one_card.toString())))continue; - QString card_str_formatted = QString::fromStdString(one_card.toFormattedString()); + if(board_qstring.contains(QString::fromStdString(one_card.getCard())))continue; + QString card_str_formatted = QString::fromStdString(toFormattedString(one_card)); this->ui->turnCardBox->addItem(card_str_formatted); this->ui->riverCardBox->addItem(card_str_formatted); @@ -120,7 +120,7 @@ void StrategyExplorer::item_expanded(const QModelIndex& index){ } void StrategyExplorer::process_board(TreeItem* treeitem){ - vector board_str_arr = string_split(this->qSolverJob->board,','); + vector board_str_arr = string_split(this->qSolverJob->clt->board,','); 
vector cards; for(string one_board_str:board_str_arr){ cards.push_back(Card(one_board_str)); @@ -136,7 +136,7 @@ void StrategyExplorer::process_board(TreeItem* treeitem){ cards.push_back(Card(this->tableStrategyModel->getRiverCard())); } } - this->ui->boardLabel->setText(QString("%1: ").arg(tr("board")) + Card::boardCards2html(cards)); + this->ui->boardLabel->setText(QString("%1: ").arg(tr("board")) + boardCards2html(cards)); } void StrategyExplorer::process_treeclick(TreeItem* treeitem){