From acd1ef80c2452970aa694811fc231dd8969b9c99 Mon Sep 17 00:00:00 2001 From: yaneurao Date: Tue, 17 Dec 2024 21:05:06 +0900 Subject: [PATCH] =?UTF-8?q?-=20PolicyBook=E3=81=AE=E3=82=B3=E3=83=BC?= =?UTF-8?q?=E3=83=89=E3=80=81GitHub=E3=81=AB=E5=8F=8D=E6=98=A0=E5=BF=98?= =?UTF-8?q?=E3=82=8C=E3=81=A6=E3=81=84=E3=81=9F=E3=81=A8=E3=81=93=E3=82=8D?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E3=80=82=20-=20vector=E3=81=AB=E5=AF=BE?= =?UTF-8?q?=E3=81=97=E3=81=A6&v=E3=81=BF=E3=81=9F=E3=81=84=E3=81=AB?= =?UTF-8?q?=E3=81=97=E3=81=A6=E3=82=A2=E3=83=89=E3=83=AC=E3=82=B9=E3=82=92?= =?UTF-8?q?=E5=8F=96=E3=81=A3=E3=81=A6=E3=81=84=E3=81=9F=E3=81=A8=E3=81=93?= =?UTF-8?q?=E3=82=8D=E3=80=81v.data()=E3=82=92=E4=BD=BF=E3=81=86=E3=82=88?= =?UTF-8?q?=E3=81=86=E3=81=AB=E4=BF=AE=E6=AD=A3=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- source/book/book.cpp | 17 +++++----- source/book/book.h | 6 +++- source/book/policybook.cpp | 63 ++++++++++++++++++++++++++++++++++--- source/book/policybook.h | 9 ++++-- source/eval/evaluate_io.cpp | 4 +-- source/learn/learner.cpp | 4 +-- source/misc.cpp | 2 +- 7 files changed, 85 insertions(+), 20 deletions(-) diff --git a/source/book/book.cpp b/source/book/book.cpp index 0d38d89fa..b59a0a795 100644 --- a/source/book/book.cpp +++ b/source/book/book.cpp @@ -1492,7 +1492,7 @@ namespace Book Position pos; string root_sfen = "startpos moves 7g7f 3c3d 6g6f 8b3b 8h7g 5a6b 2h8h 6b7b 8g8f 3d3e 8f8e 3e3f 3i2h 3f3g+ 2h3g 3a4b 4i3h P*3f 3g2h 4b3c 6i5h 3c4d 7i6h 1c1d P*3g 7a8b"; deque si; - BookTools::feed_position_string(pos, root_sfen, si, [](Position&){}); + BookTools::feed_position_string(pos, root_sfen, si, [](Position&,Move){}); string moves1 = "1g1f 2g2f 3g3f 4g4f 5g5f 6f6e 7f7e 8e8d 9g9f 1i1h 9i9h 2h3i 6h6g 6h7i 7g8f 7g9e 8h7h 8h8f 8h8g 8h9h 3h3i 3h4h 5h4h 5h6g 5i4h 5i4i 5i6i"; string moves2 = string(); @@ -1518,7 +1518,8 @@ namespace BookTools // "sfen xxx moves yyy ..." // また、局面を1つ進めるごとにposition_callback関数が呼び出される。 // 辿った局面すべてに対して何かを行いたい場合は、これを利用すると良い。 - void feed_position_string(Position& pos, const std::string& root_sfen, std::deque& si, const std::function& position_callback) + void feed_position_string(Position& pos, const std::string& root_sfen, std::deque& si, + const std::function& position_callback) { // issから次のtokenを取得する auto feed_next = [](Parser::LineScanner& iss) @@ -1576,9 +1577,6 @@ namespace BookTools } } while (token == "startpos" || token == "sfen" || token == "moves"/* movesは無視してループを回る*/ ); - // callbackを呼び出してやる。 - position_callback(pos); - // moves以降は1手ずつ進める while (token != "") { @@ -1591,14 +1589,17 @@ namespace BookTools if (!move.is_ok()) break; + // callbackを呼び出してやる。 + position_callback(pos, move); + si.emplace_back(StateInfo()); pos.do_move(move, si.back()); - // callbackを呼び出してやる。 - position_callback(pos); - token = feed_next(iss); } + + // 最後の局面でcallbackを呼び出してやる。 + position_callback(pos, Move::none()); } // 平手、駒落ちの開始局面集 diff --git a/source/book/book.h b/source/book/book.h index 73be5b8c5..0a9f3d191 100644 --- a/source/book/book.h +++ b/source/book/book.h @@ -333,7 +333,11 @@ namespace BookTools // "sfen xxx moves yyy ..." // また、局面を1つ進めるごとにposition_callback関数が呼び出される。 // 辿った局面すべてに対して何かを行いたい場合は、これを利用すると良い。 - void feed_position_string(Position& pos, const std::string& root_sfen, std::deque& si, const std::function& position_callback = [](Position&) {}); + // + // position_callbackは、その局面と、その局面での指し手が引数にセットされて呼び出される。 + // 与えたsfenの最後の局面では、MoveはMove::none()が入って呼び出される。 + void feed_position_string(Position& pos, const std::string& root_sfen, std::deque& si, + const std::function& position_callback = [](Position&, Move) {}); // 平手、駒落ちの開始局面集 // ここで返ってきた配列の、[0]は平手のsfenであることは保証されている。 diff --git a/source/book/policybook.cpp b/source/book/policybook.cpp index 5d2b9de1d..8e395ca52 100644 --- a/source/book/policybook.cpp +++ b/source/book/policybook.cpp @@ -9,7 +9,7 @@ #include "../position.h" #include "../thread.h" #include "../usi.h" - +#include "../book/book.h" // freqの和がUINT16_MAXに収まるようにする。 u16 MoveFreq32Record::overflow_check() @@ -61,8 +61,8 @@ Tools::Result PolicyBook::read_book_db(std::string path) reader.ReadLine(sfen); if (sfen != "#YANEURAOU-POLICY-DB2024 1.00") { - sync_cout << "info string Error! policy book header" << sync_endl; - Tools::exit(); + sync_cout << "info string Error! invalid policy book header" << sync_endl; + return Tools::ResultCode::FileMismatch; } while (true) { @@ -161,6 +161,8 @@ Tools::Result PolicyBook::read_book_db_bin(std::string path) // PolicyBookを読み込み、."db.bin"ファイルを書き出す。 Tools::Result PolicyBook::read_book() { + +#if !defined(ENABLE_POLICY_BOOK_LEARN) // まだ読み込んでいないならば.. if (!is_loaded()) { @@ -177,7 +179,35 @@ Tools::Result PolicyBook::read_book() return result; } - return Tools::Result::Ok(); // 読み込めたことにしておく。 +#else + // ただし、ENABLE_POLICY_BOOK_LEARNが定義されているときは、毎回読み込む。(局後学習データがあるため) + + // binary化されたPolicyBookがあるなら、それを読み込む。 + Tools::Result result = read_book_db_bin(); + if (result.is_not_ok()) + { + result = read_book_db(); + if (result.is_ok()) + { + // "db.bin"形式で書き出しておく。(次回の読み込み高速化のため) + result = write_book_db_bin(); + } + } + + // そもそも読み込んでいないのでmerge不要。 + if (result.is_not_ok()) + result = read_book_db_bin(POLICY_BOOK_LEARN_DB_BIN_NAME); + else { + PolicyBook pb; + result = pb.read_book_db_bin(POLICY_BOOK_LEARN_DB_BIN_NAME); + // 読み込みに成功したのでmergeする。 + if (result.is_ok()) + merge_book(pb); + } + +#endif + + return result; // 読み込めたことにしておく。 } @@ -309,6 +339,31 @@ PolicyBookEntry* PolicyBook::probe_policy_book(HASH_KEY key) return (it != book_body.end() && it->key == key) ? &*it : nullptr; } +// "position "コマンドのposition以降の文字列を渡して、それを +// POLICY_BOOK_LEARN_DB_BIN_NAMEにappendで書き出す。 +void PolicyBook::append_sfen_to_db_bin(const std::string& sfen) +{ + Position pos; + StateList si; + std::vector entries; + + BookTools::feed_position_string(pos, sfen, si, [&](Position& p, Move m) { + // 最後の局面は、m==Move::none()が入ってくる。 + if (m == Move::none()) + return; + PolicyBookEntry entry; + entry.key = p.hash_key(); + entry.move_freq[0] = MoveFreq(m.to_move16(), 1); + entries.push_back(entry); + }); + + // ファイルにappendする。 + SystemIO::BinaryWriter writer; + writer.Open(POLICY_BOOK_LEARN_DB_BIN_NAME, true); + auto result = writer.Write(entries.data(), sizeof(PolicyBookEntry) * entries.size()); + sync_cout << "info string append " << POLICY_BOOK_LEARN_DB_BIN_NAME << ". status = " << result.to_string() << sync_endl; +} + #if 0 // PolicyBookのmergeが正常にできているかをテストするコード。 void merge_test() diff --git a/source/book/policybook.h b/source/book/policybook.h index 81f1d0b58..7227014af 100644 --- a/source/book/policybook.h +++ b/source/book/policybook.h @@ -12,8 +12,9 @@ static_assert(HASH_KEY_BITS == 128 , "HASH_KEY_BITS must be 128"); -#define POLICY_BOOK_DB_NAME "eval/policy_book.db" -#define POLICY_BOOK_DB_BIN_NAME "eval/policy_book.db.bin" +#define POLICY_BOOK_DB_NAME "eval/policy_book.db" +#define POLICY_BOOK_DB_BIN_NAME "eval/policy_book.db.bin" +#define POLICY_BOOK_LEARN_DB_BIN_NAME "eval/policy_book-learn.db.bin" // ============================================================ // Policy Book @@ -84,6 +85,10 @@ class PolicyBook // PolicyBook同士のmerge void PolicyBook::merge_book(const PolicyBook& book); + // "position "コマンドのposition以降の文字列を渡して、それを + // POLICY_BOOK_LEARN_DB_BIN_NAMEにappendで書き出す。 + void append_sfen_to_db_bin(const std::string& sfen); + // ファイルから読み込んだか? bool is_loaded() const { return book_body.size() != 0; } diff --git a/source/eval/evaluate_io.cpp b/source/eval/evaluate_io.cpp index ebfcc98d0..dd587e3e4 100644 --- a/source/eval/evaluate_io.cpp +++ b/source/eval/evaluate_io.cpp @@ -111,14 +111,14 @@ namespace EvalIO { std::vector buffer(input_block_size); std::ifstream ifs(in_.file_or_memory.filename, std::ios::binary); - if (ifs) ifs.read(reinterpret_cast(&buffer[0]), input_block_size); + if (ifs) ifs.read(reinterpret_cast(buffer.data()), input_block_size); else { std::cout << "info string read file error , file = " << in_.file_or_memory.filename << std::endl; return false; }; std::ofstream ofs(out_.file_or_memory.filename, std::ios::binary); - if (ofs) ofs.write(reinterpret_cast(&buffer[0]), output_block_size); + if (ofs) ofs.write(reinterpret_cast(buffer.data()), output_block_size); else { std::cout << "info string write file error , file = " << out_.file_or_memory.filename << std::endl; diff --git a/source/learn/learner.cpp b/source/learn/learner.cpp index f35a498ab..94f56585d 100644 --- a/source/learn/learner.cpp +++ b/source/learn/learner.cpp @@ -218,7 +218,7 @@ struct SfenWriter { for (auto ptr : buffers) { - fs.write((const char*)&((*ptr)[0]), sizeof(PackedSfenValue) * ptr->size()); + fs.write(reinterpret_cast(ptr->data()), sizeof(PackedSfenValue) * ptr->size()); sfen_write_count += ptr->size(); @@ -2314,7 +2314,7 @@ void shuffle_files(const vector& filenames , const string& output_file_n // ファイルに書き出す fstream fs; fs.open(make_filename(write_file_count++), ios::out | ios::binary); - fs.write((char*)&buf[0], size * sizeof(PackedSfenValue)); + fs.write((char*)buf.data(), size * sizeof(PackedSfenValue)); fs.close(); a_count.push_back(size); diff --git a/source/misc.cpp b/source/misc.cpp index 7b1be214a..b88f9eb6f 100644 --- a/source/misc.cpp +++ b/source/misc.cpp @@ -1377,7 +1377,7 @@ namespace SystemIO // 今回のループで書き込むbyte数 write_size = buf_size - write_cursor; std::memcpy(&buf[write_cursor], ptr2, write_size); - if (fwrite(&buf[0], buf_size, 1, fp) == 0) + if (fwrite(buf.data(), buf_size, 1, fp) == 0) return Tools::ResultCode::FileWriteError; // buf[0..write_cursor-1]が窓で、ループごとにその窓がbuf_sizeずつずれていくと考える。