-
Notifications
You must be signed in to change notification settings - Fork 0
/
generator.py
86 lines (65 loc) · 2.67 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from huggingface_hub import notebook_login
from datetime import datetime
import os
import random
import numpy as np
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
def disabled_safety_checker(images, **kwargs):
    """Replacement for the pipeline's safety checker that filters nothing.

    Assigned to ``pipe.safety_checker`` below. It returns the images
    unchanged together with ``False`` as the second element, matching the
    ``(images, flag)`` pair the pipeline unpacks from its checker. Extra
    keyword arguments the pipeline passes (e.g. the CLIP input) are
    ignored.
    """
    return images, False


def load_prompts(path):
    """Return the prompts in *path*, one per line, stripped of whitespace.

    Blank lines are skipped. A prompt's class is derived from its position
    in the returned list, so an empty entry would shift every later prompt
    into the wrong class (or past the end of the class list).
    """
    # `with` guarantees the file is closed; the original handle was leaked.
    with open(path, "r", encoding="utf-8") as fh:
        return [line.strip() for line in fh if line.strip()]


if __name__ == "__main__":
    CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer']
    # Prompt k is labelled CLASS_NAMES[k // PROMPTS_PER_CLASS], so the prompt
    # file is expected to list its prompts grouped by class, in this order.
    PROMPTS_PER_CLASS = 5
    # Number of generation rounds; every prompt is rendered once per round.
    n_predictions = 1200
    output_dir = "./images_gen_gpt3_prompts"

    # Load and validate prompts before any expensive model work, so a bad
    # prompt file fails immediately instead of mid-run with an IndexError.
    prompts = load_prompts('gpt_3_prompts1.txt')
    max_prompts = PROMPTS_PER_CLASS * len(CLASS_NAMES)
    if len(prompts) > max_prompts:
        raise ValueError(
            f"gpt_3_prompts1.txt has {len(prompts)} prompts, but at most "
            f"{max_prompts} ({PROMPTS_PER_CLASS} per class for "
            f"{len(CLASS_NAMES)} classes) can be mapped to CLASS_NAMES"
        )

    for class_name in CLASS_NAMES:
        os.makedirs(os.path.join(output_dir, class_name), exist_ok=True)

    # Fixed seeds so reruns start from the same RNG state (an earlier run of
    # this script used seed 0). Seeding happens before the model is loaded,
    # as in the original script, so the RNG state seen by generation is
    # unchanged.
    torch.manual_seed(1)
    random.seed(1)
    np.random.seed(1)

    lms = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
    )
    # Never commit access tokens to source control: the token is read from
    # the HF_TOKEN environment variable. (Any token previously hard-coded
    # here should be considered leaked and revoked.) When HF_TOKEN is unset,
    # None is passed; depending on the huggingface_hub version that either
    # uses the token cached by `huggingface-cli login` or downloads
    # anonymously — TODO confirm for the pinned version.
    hf_token = os.environ.get("HF_TOKEN")
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        scheduler=lms,
        use_auth_token=hf_token,
    ).to("cuda")
    pipe.safety_checker = disabled_safety_checker

    for i in range(n_predictions):
        for prompt_index, prompt in enumerate(prompts):
            class_name = CLASS_NAMES[prompt_index // PROMPTS_PER_CLASS]
            with autocast("cuda"):
                # NOTE(review): `["sample"]` is the output key of older
                # diffusers releases; newer ones may expose `.images`
                # instead — confirm against the installed version.
                image = pipe(prompt, height=512, width=512)["sample"][0]
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # NOTE(review): names are unique only to the second within a
            # round; two same-class prompts finishing within one second of
            # each other would overwrite one another.
            img_name = f"{class_name}_{timestamp}_{i}.png"
            image.save(os.path.join(output_dir, class_name, img_name))
        # `i` is the 0-based index of the round that just finished.
        if i % 10 == 0:
            print(f"{i} completed")