generate_images.py
import argparse
import os

import pandas as pd
import torch
import wandb
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
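
# Example invocation (illustrative; the prompt, output directory, and image count are
# placeholders, not values taken from the repository):
#   python generate_images.py --prompts "a photo of an astronaut riding a horse" \
#       --save-dir ./generated --n 10 --model-id CompVis/stable-diffusion-v1-4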


def main(args):
    if args.wandb_silent:
        os.environ["WANDB_SILENT"] = "true"

    device = "cuda"
    print(f"Prompts: {args.prompts}")

    # Shared initial latents so every model is sampled from identical noise.
    latents = torch.load(
        "applications/Diffusion/generation/latents.pt", map_location=device
    )

    wandb.init(
        project="VisDiff-Diffusion",
        group="generated_images",
        name=args.prompts[0],
        config=vars(args),
    )

    # Build a single negative-prompt string from the line-separated file.
    with open("applications/Diffusion/generation/negative_prompts.txt", "r") as f:
        negative_prompts = [line.replace("\n", "") for line in f.readlines()]
    negative_prompt = ", ".join(negative_prompts)
    print(f"Negative Prompt: {negative_prompt}")
    for model_id in args.model_id:
        # Use the Euler scheduler instead of the pipeline's default scheduler.
        scheduler = EulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            scheduler=scheduler,
            torch_dtype=torch.float16,
            requires_safety_checker=False,
            safety_checker=None,
        )
        pipe = pipe.to(device)

        # Resolve the prompt list: a named prompt set or the prompts given on the CLI.
        if args.prompts == ["PartiPrompts"]:
            parti_prompts = pd.read_csv(
                "applications/Diffusion/generation/parti-prompts.csv"
            )
            prompts = parti_prompts["Prompt"].tolist()
        elif args.prompts == ["DiffusionDB"]:
            with open("applications/Diffusion/generation/diffusiondb.txt", "r") as f:
                prompts = [line.replace("\n", "") for line in f.readlines()]
        else:
            prompts = args.prompts
        for prompt in prompts:
            # Truncated, filesystem-safe directory name; skip prompts that have
            # already been generated for this model. The same truncated path is
            # used for both the skip check and saving.
            save_dir = f'{args.save_dir}/{prompt.replace(" ", "_").replace(".", "")[:100]}/{model_id}'
            if os.path.exists(save_dir):
                continue

            print(f"Generating images for prompt: {prompt}")
            with torch.autocast("cuda"):
                # Generate in batches of at most 5 images, reusing the matching
                # slice of the shared latents for each batch.
                images = []
                step = min(5, args.n)
                for i in range(0, args.n, step):
                    images += pipe(
                        prompt,
                        negative_prompt=negative_prompt,
                        num_images_per_prompt=step,
                        guidance_scale=7.5,
                        latents=latents[i : i + step],
                    ).images

            os.makedirs(save_dir, exist_ok=True)
            for s, image in enumerate(images):
                image.save(f"{save_dir}/{s}.png")

            wandb.log(
                {
                    f"{model_id}-{prompt}": [
                        wandb.Image(image) for image in images[: min(20, args.n)]
                    ]
                }
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Dataset Understanding")
    parser.add_argument("--prompts", type=str, nargs="+", help="prompts")
    parser.add_argument(
        "--save-dir",
        type=str,
        required=True,
        help="save directory",
    )
    parser.add_argument("--test", action="store_true", help="test mode")
    parser.add_argument(
        "--n", type=int, default=50, help="number of images to generate"
    )
    parser.add_argument("--wandb-silent", action="store_true")
    parser.add_argument(
        "--model-id",
        type=str,
        nargs="+",
        # The default must be a list to match nargs="+"; a bare string default
        # would be iterated character by character in main().
        default=["CompVis/stable-diffusion-v1-4"],
        help="huggingface model id",
    )
    args = parser.parse_args()
    main(args)
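
# Note on the shared latents file: "applications/Diffusion/generation/latents.pt" is
# expected to hold a pre-sampled batch of initial noise so that every model is run
# from identical latents. A minimal sketch of how such a file could be produced (the
# shape n x 4 x 64 x 64 is an assumption for 512x512 Stable Diffusion v1 outputs,
# not something this script checks):
#
#   noise = torch.randn(50, 4, 64, 64, generator=torch.manual_seed(0)).half()
#   torch.save(noise, "applications/Diffusion/generation/latents.pt")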