Skip to content

Commit

Permalink
Add TI training support
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Jan 26, 2023
1 parent 49bada0 commit 03bd2e9
Show file tree
Hide file tree
Showing 14 changed files with 1,655 additions and 20 deletions.
138 changes: 138 additions & 0 deletions README-ja.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
## リポジトリについて
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。

[README in English](./README.md) ←更新情報はこちらにあります

GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。

以下のスクリプトがあります。

* DreamBooth、U-NetおよびText Encoderの学習をサポート
* fine-tuning、同上
* 画像生成
* モデル変換(Stable Diffision ckpt/safetensorsとDiffusersの相互変換)

## 使用法について

当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください(将来的にはすべてこちらへ移すかもしれません)。

* [DreamBoothの学習について](./train_db_README-ja.md)
* [fine-tuningのガイド](./fine_tune_README_ja.md):
BLIPによるキャプショニングと、DeepDanbooruまたはWD14 taggerによるタグ付けを含みます
* [LoRAの学習について](./train_network_README-ja.md)
* [Textual Inversionの学習について](./train_ti_README-ja.md)
* note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e)
* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)

## Windowsでの動作に必要なプログラム

Python 3.10.6およびGitが必要です。

- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
- git: https://git-scm.com/download/win

PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。
(venvに限らずスクリプトの実行が可能になりますので注意してください。)

- PowerShellを管理者として開きます。
- 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。
- 管理者のPowerShellを閉じます。

## Windows環境でのインストール

以下の例ではPyTorchは1.12.1/CUDA 11.6版をインストールします。CUDA 11.3版やPyTorch 1.13を使う場合は適宜書き換えください。

(なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。)

通常の(管理者ではない)PowerShellを開き以下を順に実行します。

```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
accelerate config
```

コマンドプロンプトでは以下になります。


```bat
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
accelerate config
```

(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。)

accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。)

※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます(……)。数字キーの0、1、2……で選択できますので、そちらを使ってください。

```txt
- This machine
- No distributed training
- NO
- NO
- NO
- all
- fp16
```

※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)

### PyTorchとxformersのバージョンについて

他のバージョンでは学習がうまくいかない場合があるようです。特に他の理由がなければ指定のバージョンをお使いください。

## アップグレード

新しいリリースがあった場合、以下のコマンドで更新できます。

```powershell
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --upgrade -r <requirement file name>
```

コマンドが成功すれば新しいバージョンが使用できます。

## 謝意

LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。

## ライセンス

スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。

[Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT

[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT

[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause


10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ This repository repository is providing a Gradio GUI for kohya's Stable Diffusio

Python 3.10.6+ and Git:

- Python 3.10.6+: https://www.python.org/ftp/python/3.10.9/python-3.10.9-amd64.exe
- Install Python 3.10 using https://www.python.org/ftp/python/3.10.9/python-3.10.9-amd64.exe (make sure to tick the box to add Python to the environment path)
- git: https://git-scm.com/download/win
- Visual Studio 2015, 2017, 2019, and 2022 redistributable: https://aka.ms/vs/17/release/vc_redist.x64.exe

## Installation

Expand All @@ -23,7 +24,7 @@ Open a regular user Powershell terminal and type the following inside:
git clone https://github.com/bmaltais/kohya_ss.git
cd kohya_ss
python -m venv --system-site-packages venv
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
Expand All @@ -40,7 +41,7 @@ accelerate config

### Optional: CUDNN 8.6

This step is optional but can improve the learning speed for NVidia 4090 owners...
This step is optional but can improve the learning speed for NVidia 30X0/40X0 owners... It allows larger training batch size and faster training speed

Due to the filesize I can't host the DLLs needed for CUDNN 8.6 on Github, I strongly advise you download them for a speed boost in sample generation (almost 50% on 4090) you can download them from here: https://b1.thefileditch.ch/mwxKTEtelILoIbMbruuM.zip

Expand Down Expand Up @@ -130,6 +131,9 @@ Then redo the installation instruction within the kohya_ss venv.

## Change history

* 2023/01/26 (v20.5.0):
- Add new `Dreambooth TI` tab for training of Textual Inversion embeddings
- Add Textual Inversion training. Documentation is [here](./train_ti_README-ja.md) (in Japanese.)
* 2023/01/22 (v20.4.1):
- Add new tool to verify LoRA weights produced by the trainer. Can be found under "Dreambooth LoRA/Tools/Verify LoRA"
* 2023/01/22 (v20.4.0):
Expand Down
3 changes: 3 additions & 0 deletions kohya_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
from dreambooth_gui import dreambooth_tab
from finetune_gui import finetune_tab
from textual_inversion_gui import ti_tab
from library.utilities import utilities_tab
from library.extract_lora_gui import gradio_extract_lora_tab
from library.merge_lora_gui import gradio_merge_lora_tab
Expand Down Expand Up @@ -30,6 +31,8 @@ def UI(username, password):
) = dreambooth_tab()
with gr.Tab('Dreambooth LoRA'):
lora_tab()
with gr.Tab('Dreambooth TI'):
ti_tab()
with gr.Tab('Finetune'):
finetune_tab()
with gr.Tab('Utilities'):
Expand Down
4 changes: 2 additions & 2 deletions library/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,8 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
minimum=1,
maximum=os.cpu_count(),
step=1,
label='Number of CPU threads per process',
value=os.cpu_count(),
label='Number of CPU threads per core',
value=2,
)
seed = gr.Textbox(label='Seed', value=1234)
with gr.Row():
Expand Down
80 changes: 76 additions & 4 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import random
import hashlib
from io import BytesIO

from tqdm import tqdm
import torch
Expand All @@ -25,6 +26,7 @@
import cv2
from einops import rearrange
from torch import einsum
import safetensors.torch

import library.model_util as model_util

Expand Down Expand Up @@ -85,6 +87,7 @@ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_to
self.enable_bucket = False
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_info = None

self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2

Expand All @@ -110,9 +113,14 @@ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_to

self.image_data: dict[str, ImageInfo] = {}

self.replacements = {}

def disable_token_padding(self):
self.token_padding_disabled = True

def add_replacement(self, str_from, str_to):
self.replacements[str_from] = str_to

def process_caption(self, caption):
if self.shuffle_caption:
tokens = caption.strip().split(",")
Expand All @@ -125,6 +133,17 @@ def process_caption(self, caption):
random.shuffle(tokens)
tokens = keep_tokens + tokens
caption = ",".join(tokens).strip()

for str_from, str_to in self.replacements.items():
if str_from == "":
# replace all
if type(str_to) == list:
caption = random.choice(str_to)
else:
caption = str_to
else:
caption = caption.replace(str_from, str_to)

return caption

def get_input_ids(self, caption):
Expand Down Expand Up @@ -217,11 +236,17 @@ def make_buckets(self):
self.buckets[bucket_index].append(image_info.image_key)

if self.enable_bucket:
self.bucket_info = {"buckets": {}}
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")

img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}")
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
print(f"mean ar error (without repeats): {mean_img_ar_error}")


# 参照用indexを作る
self.buckets_indices: list(BucketBatchIndex) = []
Expand Down Expand Up @@ -599,7 +624,7 @@ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_to
else:
# わりといい加減だがいい方法が思いつかん
abs_path = glob_images(train_data_dir, image_key)
assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}"
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
abs_path = abs_path[0]

caption = img_md.get('caption')
Expand Down Expand Up @@ -706,15 +731,17 @@ def image_key_to_npz_file(self, image_key):
return npz_file_norm, npz_file_flip


def debug_dataset(train_dataset):
def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
k = 0
for example in train_dataset:
if example['latents'] is not None:
print("sample has latents from npz file")
for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])):
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
if show_input_ids:
print(f"input ids: {iid}")
if example['images'] is not None:
im = example['images'][j]
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
Expand Down Expand Up @@ -790,6 +817,49 @@ def calculate_sha256(filename):
return hash_sha256.hexdigest()


def precalculate_safetensors_hashes(tensors, metadata):
"""Precalculate the model hashes needed by sd-webui-additional-networks to
save time on indexing the model later."""

# Because writing user metadata to the file can change the result of
# sd_models.model_hash(), only retain the training metadata for purposes of
# calculating the hash, as they are meant to be immutable
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

bytes = safetensors.torch.save(tensors, metadata)
b = BytesIO(bytes)

model_hash = addnet_hash_safetensors(b)
legacy_hash = addnet_hash_legacy(b)
return model_hash, legacy_hash


def addnet_hash_legacy(b):
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
m = hashlib.sha256()

b.seek(0x100000)
m.update(b.read(0x10000))
return m.hexdigest()[0:8]


def addnet_hash_safetensors(b):
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
hash_sha256 = hashlib.sha256()
blksize = 1024 * 1024

b.seek(0)
header = b.read(8)
n = int.from_bytes(header, "little")

offset = n + 8
b.seek(offset)
for chunk in iter(lambda: b.read(blksize), b""):
hash_sha256.update(chunk)

return hash_sha256.hexdigest()


# flash attention forwards and backwards

# https://arxiv.org/abs/2205.14135
Expand Down Expand Up @@ -1057,6 +1127,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
Expand Down
3 changes: 3 additions & 0 deletions lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ def train_model(
msgbox('Output folder path is missing')
return

if not os.path.exists(output_dir):
os.makedirs(output_dir)

if stop_text_encoder_training_pct > 0:
msgbox('Output "stop text encoder training" is not yet supported. Ignoring')
stop_text_encoder_training_pct = 0
Expand Down
Loading

0 comments on commit 03bd2e9

Please sign in to comment.