-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
41 changed files
with
1,911 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
pip-wheel-metadata/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# Project specific folders | ||
lightning_logs/ | ||
tf_logs/ | ||
remote_logs/ | ||
local_logs/ | ||
*_logs/ | ||
checkpoints/ | ||
dataset/ | ||
local_misc/ | ||
notebooks/data/ | ||
|
||
# LSF logfiles | ||
lsf.* | ||
|
||
# IDEA folders | ||
.idea/ | ||
|
||
/work_dirs/* | ||
/data/* | ||
!.gitkeep |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
PROJECT ?= bert-mcts | ||
DATADIR ?= ${PWD}/data | ||
WORKSPACE ?= /workspace/$(PROJECT) | ||
DOCKER_IMAGE ?= ${PROJECT}:latest | ||
|
||
SHMSIZE ?= 100G | ||
DOCKER_OPTS := \ | ||
--name ${PROJECT} \ | ||
--rm -it \ | ||
--shm-size=${SHMSIZE} \ | ||
-v ${PWD}:${WORKSPACE} \ | ||
-v ${DATADIR}:${WORKSPACE}/data \ | ||
-v ${LOG_DIR}:${WORKSPACE}/work_dirs/logs \ | ||
-w ${WORKSPACE} \ | ||
--ipc=host \ | ||
--network=host \ | ||
--gpus all | ||
|
||
docker-build: | ||
docker build -f docker/Dockerfile -t ${DOCKER_IMAGE} . | ||
|
||
docker-start-interactive: docker-build | ||
docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} bash | ||
|
||
docker-start-jupyter: docker-build | ||
docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \ | ||
bash -c "jupyter lab --port=8888 --ip=0.0.0.0 --allow-root --no-browser" | ||
|
||
docker-run: docker-build | ||
docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \ | ||
bash -c "${COMMAND}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,101 @@ | ||
# bert-mcts-youtube | ||
# BERT-MCTS-YOUTUBE | ||
|
||
YouTubeにてヨビノリたくみさんと対戦した将棋ソフトです。 | ||
自然言語モデルであるBERTとモンテカルロ木探索(MCTS)の組み合わせで出来ています。 | ||
すべてpythonで書いてあるため、探索の速度は遅いです。 | ||
|
||
BERT以外の大部分は『将棋AIで学ぶディープラーニング』を参考に書いています。 | ||
- [書籍(amazon)](https://www.amazon.co.jp/dp/B07B7JJ929) | ||
- [github](https://github.com/TadaoYamaoka/python-dlshogi) | ||
|
||
## 環境 | ||
|
||
### Colab | ||
|
||
テストするだけなら[google colab](https://colab.research.google.com/drive/10KAuLlNe6FKZBp_iE2bQJPNhoY2WeACx?usp=sharing) が簡単です。 | ||
|
||
以下はローカルで試す場合。CPUだと遅いのでCUDA環境が望ましいです。 | ||
重みファイルは[ここ](https://drive.google.com/drive/folders/1N-Np2NmNLtLGS9gjnreBkYdTxrH1EHFw?usp=sharing) にアップしてあり、 | ||
たくみさんと戦った重みファイルがyoutube_version.ckpt、追加で数日間学習させた重みファイルがlatest.ckptになります。 | ||
ダウンロード先のパスはengine/***_player.sh内で指定してください。 | ||
デフォルトではwork_dirs以下にダウンロードすることを想定しています。 | ||
|
||
### Docker | ||
|
||
cuda10.2以上のnvidia-dockerが整っているなら次のコマンドで環境に入れます。 | ||
```bash | ||
$ make docker-start-interactive | ||
``` | ||
|
||
### Ubuntu18.04 | ||
|
||
cuda10.2でanacondaが入っていれば次のコマンドで仮想環境を作れます。 | ||
```bash | ||
$ conda env create -f env_name.yml | ||
$ conda activate bert-mcts-youtube | ||
$ python setup.py develop | ||
``` | ||
|
||
### Windows10 | ||
|
||
未検証 | ||
|
||
## 将棋エンジンのテスト | ||
|
||
エンジンはengineディレクトリに用意しています。これらはShogiGUIなどから呼び出すことができます。 | ||
- policy_player.shはBERTの出力する方策のみを頼りに指すモデル(弱い) | ||
- mcts_player.shはBERTの出力をもとにMCTSで探索するモデル | ||
|
||
## 学習 | ||
|
||
学習には互角局面集とGCTの自己対戦棋譜を用いました。 | ||
モデルはMasked Language Modelで事前学習してから、Policy Value Networkの学習という手順を踏みます。 | ||
ただし、将棋は良質な教師データが大量にあるため事前学習の効果はあまりない気がします。 | ||
|
||
### データの準備 | ||
|
||
互角局面集のダウンロード | ||
```bash | ||
$ cd data | ||
$ git clone https://github.com/tttak/ShogiGokakuKyokumen.git | ||
``` | ||
|
||
GCTの自己対戦棋譜 | ||
```bash | ||
$ cd data | ||
$ mkdir hcpe | ||
``` | ||
|
||
GCTの自己対戦棋譜は開発者の加納さんが[リンク](https://drive.google.com/drive/folders/14FaqqIHRctTQIY6hScCFXWQQZ_pSU3-F) | ||
に公開してくださっています。 | ||
ここからselfplay-***となっているファイルをいくつかdata/hcpe以下にダウンロードしてください。 | ||
サイズが大きいので一個でも十分な量があります。 | ||
|
||
これらを準備できたら以下のコマンドでデータセットを作ります。 | ||
|
||
```bash | ||
$ python tools/make_dataset.py | ||
``` | ||
|
||
### Masked Language Modelの学習 | ||
|
||
```bash | ||
$ python tools/train.py configs/mlm_base.yaml | ||
``` | ||
|
||
### 重みファイルの変換 | ||
|
||
Masked Language Modelのチェックポイントをtransformers形式に変換しておきます。 | ||
これによって転移学習のコードが多少書きやすくなります。 | ||
|
||
```bash | ||
$ python tools/pl_to_transformers.py work_dirs/mlm_base/version_0/checkpoints/last.ckpt | ||
``` | ||
|
||
### Policy Value Modelの学習 | ||
|
||
最後にこれらを使ってPolicy Valueを学習させます。 | ||
|
||
```bash | ||
$ python tools/train.py configs/policy_value.yaml | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
model_type: 'MLM' | ||
|
||
seed: 42 | ||
dataset_dir: './data/dataset/gokaku_100' | ||
model_dir: | ||
|
||
train_loader: | ||
batch_size: 64 | ||
shuffle: True | ||
num_workers: 8 | ||
pin_memory: False | ||
drop_last: True | ||
|
||
val_loader: | ||
batch_size: 64 | ||
shuffle: False | ||
num_workers: 8 | ||
pin_memory: False | ||
drop_last: False | ||
|
||
train_params: | ||
max_epochs: 5 | ||
# validationおよびcheckpointの間隔step数 | ||
val_check_interval: 3000 | ||
# 環境に応じて,適宜変更 | ||
gpus: [0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
model_type: 'PolicyValue' | ||
|
||
seed: 42 | ||
dataset_dir: './data/dataset/selfplay' | ||
model_dir: './work_dirs/mlm_base/version_0/checkpoints' | ||
|
||
train_loader: | ||
batch_size: 128 | ||
shuffle: True | ||
num_workers: 0 | ||
pin_memory: True | ||
drop_last: True | ||
|
||
val_loader: | ||
batch_size: 128 | ||
shuffle: False | ||
num_workers: 0 | ||
pin_memory: True | ||
drop_last: False | ||
|
||
train_params: | ||
max_epochs: 1 | ||
# validationおよびcheckpointの間隔step数 | ||
val_check_interval: 30000 | ||
limit_val_batches: 0.1 | ||
# 環境に応じて,適宜変更 | ||
gpus: [0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime | ||
ENV DEBIAN_FRONTEND=noninteractive | ||
ADD requirements.txt /tmp | ||
RUN pip install -r /tmp/requirements.txt | ||
|
||
ADD docker/entrypoint.sh /tmp | ||
ENTRYPOINT ["bash", "/tmp/entrypoint.sh"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
python setup.py develop | ||
exec "$@" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#!/bin/sh | ||
python -m src.player.mcts_player --ckpt_path ./work_dirs/youtube_version.ckpt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#!/bin/sh | ||
python -m src.player.policy_player --ckpt_path ./work_dirs/youtube_version.ckpt |
Oops, something went wrong.