ι208研究室の (0, 0)
と (13, 12)
の2地点をCNNで分類する、ディープラーニングモデルです。
サンプルは Python 3.10 と PyTorch 2.1.1 によって構築されています。
- GitHubからクローンして、ディレクトリに移動します。
git clone https://github.com/nkzwlab/dhacks-pytorch-cnn.git
cd dhacks-pytorch-cnn
- Pythonのvenvを使用して作業ディレクトリに仮想環境を作成した後、移動します。
- Mac, Linux
python -m venv venv
source venv/bin/activate
- Windows
python -m venv venv
.\venv\Scripts\activate
- 仮想環境の中で、必要なライブラリをインストールします。
pip install -r requirements.txt
- Weights & Biases にログインします。
学習経過のグラフを見ることができるサイトです。
事前に、https://www.wandb.jp/ からW&Bアカウントを作成しておいてください。
wandb login
- 以上で、モデルの学習を開始できます!
python main.py
推論時には、以下のように main.py のpredict
のコメントアウトを外し、model_path の20231119175655
の部分を使用したいモデルのパスに置き換えることで、テストデータへの予測精度を確認できます。
# train()
predict(model_path="outputs/20231119175655/model.pth")