This repository has been archived by the owner on May 14, 2024. It is now read-only.
Merge pull request #9 from Linaqruf/experimental
Publish Kohya Trainer V8
Showing 23 changed files with 6,027 additions and 2,218 deletions.
convert_diffusers20_original_sd/convert_diffusers20_original_sd.py (93 additions, 0 deletions)
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
# v1: initial version
# v2: support safetensors
# v3: fix to support another format
# v4: support safetensors in Diffusers

import argparse
import os
import torch
from diffusers import StableDiffusionPipeline

import model_util
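# note: model_util is assumed to be the helper module that ships alongside this script in
# the repository; it provides the checkpoint load/save functions used below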


def convert(args):
    # check the arguments
    load_dtype = torch.float16 if args.fp16 else None

    save_dtype = None
    if args.fp16:
        save_dtype = torch.float16
    elif args.bf16:
        save_dtype = torch.bfloat16
    elif args.float:
        save_dtype = torch.float

    is_load_ckpt = os.path.isfile(args.model_to_load)
    is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
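    # mode selection: an existing file at model_to_load is read as a single checkpoint,
    # otherwise it is read as a Diffusers directory; model_to_save is written as a
    # checkpoint when it has a file extension, otherwise as a Diffusers directory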

    assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
    assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"

    # load the model
    msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
    print(f"loading {msg}: {args.model_to_load}")

    if is_load_ckpt:
        v2_model = args.v2
        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
        text_encoder = pipe.text_encoder
        vae = pipe.vae
        unet = pipe.unet

        if args.v1 == args.v2:
            # determine the model version automatically
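            # the cross-attention dimension is 768 for SD v1.x text encoders and
            # 1024 for SD v2.x, so 1024 identifies a v2 model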
            v2_model = unet.config.cross_attention_dim == 1024
            print("checking model version: model is " + ('v2' if v2_model else 'v1'))
        else:
            v2_model = args.v2

    # convert and save the model
    msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
    print(f"converting and saving as {msg}: {args.model_to_save}")

    if is_save_ckpt:
        original_model = args.model_to_load if is_load_ckpt else None
        key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
                                                                original_model, args.epoch, args.global_step, save_dtype, vae)
        print(f"model saved. total converted state_dict keys: {key_count}")
    else:
        print(f"copy scheduler/tokenizer config from: {args.reference_model}")
        model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
        print("model saved.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--v1", action='store_true',
                        help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
    parser.add_argument("--v2", action='store_true',
                        help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
    parser.add_argument("--fp16", action='store_true',
                        help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
    parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
    parser.add_argument("--float", action='store_true',
                        help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
    parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
    parser.add_argument("--global_step", type=int, default=0,
                        help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
    parser.add_argument("--reference_model", type=str, default=None,
                        help="reference model for scheduler/tokenizer, required in saving Diffusers, copy scheduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
    parser.add_argument("--use_safetensors", action='store_true',
                        help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Diffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)")

    parser.add_argument("model_to_load", type=str, default=None,
                        help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
    parser.add_argument("model_to_save", type=str, default=None,
                        help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusersモデルとして保存")

    args = parser.parse_args()
    convert(args)
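
For reference, a sketch of how the converter might be invoked, assuming it is run from the folder that contains the script and model_util.py; the input and output paths below are placeholders, not files from this repository. The first call reads a Diffusers model directory and writes a fp16 safetensors checkpoint (the format is chosen from the file extension); the second reads a v1.x checkpoint and writes a Diffusers directory, copying the scheduler/tokenizer config from an existing Diffusers model:

python convert_diffusers20_original_sd.py --fp16 path/to/diffusers_model path/to/output/model.safetensors
python convert_diffusers20_original_sd.py --v1 --reference_model path/to/reference_diffusers_model path/to/model.ckpt path/to/output/diffusers_model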