-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest_dreambooth.py
116 lines (103 loc) · 4.18 KB
/
test_dreambooth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import time
from pathlib import Path
import argparse
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
def parse_args(input_args=None):
    """Build the command-line parser for this script and parse arguments.

    Args:
        input_args: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) parses ``sys.argv``,
            exactly as before; passing a list lets callers (e.g. tests or
            other scripts) drive the parser programmatically.

    Returns:
        argparse.Namespace with ``pred_path``, ``model_path``, ``token``,
        ``class_str``, ``tests``, ``num_pred_steps``, ``guide``, ``num_preds``
        and ``ddim``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a testing script.")
    parser.add_argument(
        "--pred_path",
        type=str,
        default=None,
        required=True,
        help="Path to save generate images.",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default=None,
        required=True,
        help="Path to model.",
    )
    parser.add_argument(
        "--token",
        type=str,
        default="aabbccddeeffgg",
        help="Special token.",
    )
    parser.add_argument(
        "--class_str",
        type=str,
        default="person",
        help="Class string to help the model understand what class the token is.",
    )
    parser.add_argument(
        "--tests",
        type=str,
        default="all",
        help="list of test ids. default all uses all the test prompts",
    )
    parser.add_argument(
        "--num_pred_steps",
        type=int,
        default=75,
        help="Number of steps for inference.",
    )
    parser.add_argument(
        "--guide",
        type=float,
        default=7.5,
        help="Guide power.",
    )
    parser.add_argument(
        "--num_preds",
        type=int,
        default=1,
        help="Number of predictions for each prompt.",
    )
    parser.add_argument(
        '--ddim',
        action='store_true',
        help="Flag to activate the DDIM scheduler.",
    )
    # Forward input_args: previously it was silently ignored and sys.argv was
    # always parsed. parse_args(None) still falls back to sys.argv.
    args = parser.parse_args(input_args)
    return args
def _select_test_ids(spec, available):
    """Resolve a ``--tests`` value into an ordered list of prompt ids.

    Args:
        spec: ``"all"`` (case-insensitive) to select every id in ``available``,
            or ids separated by commas and/or whitespace, e.g. ``"1,3,5"``.
        available: Container of known test ids (iterated for ``"all"``).

    Returns:
        List of ids. For ``"all"`` this is ``available``'s iteration order;
        otherwise the order given in ``spec``, with duplicates dropped.

    Raises:
        ValueError: if ``spec`` names no ids, or names an id not in ``available``.
    """
    if spec.strip().lower() == "all":
        return list(available)
    # dict.fromkeys de-duplicates while keeping first-seen order.
    requested = list(dict.fromkeys(spec.replace(",", " ").split()))
    if not requested:
        raise ValueError("--tests must be 'all' or a list of test ids")
    unknown = [test_id for test_id in requested if test_id not in available]
    if unknown:
        raise ValueError(
            f"unknown test id(s): {', '.join(unknown)}; "
            f"choose from {', '.join(available)} or 'all'"
        )
    return requested


def _unique_path(path):
    """Return ``path``, or ``<stem>-<n><suffix>`` for the first n >= 1 not on disk.

    Output names only carry a one-second timestamp, so without this an image
    saved in the same second as an existing file of the same name would
    silently overwrite it.
    """
    candidate = path
    counter = 1
    while candidate.exists():
        candidate = path.with_name(f"{path.stem}-{counter}{path.suffix}")
        counter += 1
    return candidate


def main():
    """Render the selected test prompts with a Stable Diffusion pipeline and save PNGs.

    Configuration comes from the command line via ``parse_args``. Every prompt
    embeds ``"<token> <class_str>"`` (by default ``"aabbccddeeffgg person"``).
    The pipeline is loaded from ``--model_path`` in float16, moved to CUDA,
    with ``safety_checker=None``; ``--ddim`` swaps in a DDIM scheduler.
    Each selected prompt is rendered ``--num_preds`` times and saved under
    ``--pred_path`` as ``<id>-<tag>-<YYYYmmdd-HHMMSS>.png`` (plus a ``-<n>``
    suffix if that name is already taken).

    Raises:
        ValueError: if ``--tests`` is empty or names an unknown test id. This
            is checked before the model is loaded.
    """
    args = parse_args()
    token_class_str = args.token + " " + args.class_str
    # Each entry: [prompt, num_inference_steps, guidance_scale, filename tag].
    tests = {
        "1": ["photo, colorful cinematic portrait of " + token_class_str + ", armor, cyberpunk,background made of brain cells, back light, organic, art by greg rutkowski, ultrarealistic, leica 30mm", args.num_pred_steps, args.guide, "rutkowski"],
        "2": ["pencil sketch portrait of " + token_class_str + " inpired by greg rutkowski, digital art by artgem", args.num_pred_steps, args.guide, "rutkowskiartgem"],
        "3": ["photo,colorful cinematic portrait of " + token_class_str + ", " + token_class_str + " with long hair, color lights, on stage, ultrarealistic", args.num_pred_steps, args.guide, "longhair"],
        "4": ["photo portrait of " + token_class_str + " astronaut, astronaut, helmet in alien world abstract oil painting, greg rutkowski, detailed face", args.num_pred_steps, args.guide, "astronautrutkowski"],
        "5": ["photo portrait of " + token_class_str + " as firefighter, helmet, ultrarealistic, leica 30mm", args.num_pred_steps, args.guide, "firefighter"],
        "6": ["photo portrait of " + token_class_str + " as steampunk warrior, neon organic vines, digital painting", args.num_pred_steps, args.guide, "steampunk"],
        "7": ["impressionist portrait painting of " + token_class_str + " by Daniel F Gerhartz, (( " + token_class_str + " with painted in an impressionist style)), nature, trees", args.num_pred_steps, args.guide, "danielgerhartz"],
    }
    # Honor --tests (previously parsed but ignored) and validate it before the
    # comparatively expensive model load, so a typo fails fast.
    selected_ids = _select_test_ids(args.tests, tests)

    pipe_kwargs = {"torch_dtype": torch.float16, "safety_checker": None}
    if args.ddim:
        # NOTE(review): these beta settings are presumably meant to match the
        # model's training noise schedule — confirm for non-default models.
        pipe_kwargs["scheduler"] = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
    pipe = StableDiffusionPipeline.from_pretrained(args.model_path, **pipe_kwargs).to("cuda")

    out_dir = Path(args.pred_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    for _ in range(args.num_preds):
        for key in selected_ids:
            prompt, steps, guide, tag = tests[key]
            image = pipe(prompt, num_inference_steps=steps, guidance_scale=guide).images[0]
            timestr = time.strftime("%Y%m%d-%H%M%S")
            image.save(_unique_path(out_dir / f"{key}-{tag}-{timestr}.png"))
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()