-
Notifications
You must be signed in to change notification settings - Fork 3
/
predict.py
139 lines (123 loc) · 4.43 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from cog import BasePredictor, Input, Path
import os
import sys
import torch
from PIL import Image
from clip_interrogator import Interrogator, Config
from diffusers import (
DDIMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
PNDMScheduler,
StableDiffusionXLImg2ImgPipeline
)
# Hub repository id of the SDXL base checkpoint. NOTE(review): not referenced
# by the predictor code in this file; presumably used by a separate weight
# download step — confirm before removing.
MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
# Local directory that Predictor.setup loads the SDXL pipeline weights from.
SDXL_MODEL_CACHE = "sdxl-cache/"

# Make modules under /root/blip importable. NOTE(review): presumably required
# by clip_interrogator's BLIP component when the Interrogator is built — confirm.
sys.path.append("/root/blip")
class KarrasDPM:
    """Scheduler factory: ``DPMSolverMultistepScheduler`` with Karras sigmas.

    Exposes the same ``from_config(config)`` entry point that ``predict``
    calls on every entry of ``SCHEDULERS``, so it can sit in that table
    alongside the plain diffusers scheduler classes.
    """

    @staticmethod
    def from_config(config, **kwargs):
        """Build a DPM-Solver multistep scheduler from *config* with Karras sigmas.

        Args:
            config: Scheduler config, typically ``pipe.scheduler.config`` of an
                existing pipeline.
            **kwargs: Extra overrides forwarded to
                ``DPMSolverMultistepScheduler.from_config``.
                ``use_karras_sigmas`` is always forced to ``True``, since
                enabling it is the whole purpose of this wrapper.

        Returns:
            A new ``DPMSolverMultistepScheduler`` instance.
        """
        # ``kwargs`` is a fresh dict per call, so mutating it is safe.
        kwargs["use_karras_sigmas"] = True
        return DPMSolverMultistepScheduler.from_config(config, **kwargs)
# Scheduler name -> factory. Every value must provide ``from_config(config)``:
# predict() swaps the pipeline scheduler via
# ``SCHEDULERS[name].from_config(pipe.scheduler.config)``.
# Currently predict() hardcodes the "K_EULER_ANCESTRAL" entry.
SCHEDULERS = dict(
    DDIM=DDIMScheduler,
    DPMSolverMultistep=DPMSolverMultistepScheduler,
    HeunDiscrete=HeunDiscreteScheduler,
    KarrasDPM=KarrasDPM,
    K_EULER_ANCESTRAL=EulerAncestralDiscreteScheduler,
    K_EULER=EulerDiscreteScheduler,
    PNDM=PNDMScheduler,
)
class Predictor(BasePredictor):
    """Cog predictor that restyles an input image as a cartoon portrait.

    The input image is resized to 1024x1024 and captioned with CLIP
    Interrogator. The caption is appended to a fixed style prefix, and the
    result drives an SDXL img2img pass over the resized image.
    """

    def setup(self):
        """Load the CLIP interrogator and the SDXL txt2img/img2img pipelines."""
        print("Loading CLIP pipeline...")
        self.ci = Interrogator(
            Config(
                clip_model_name="ViT-H-14/laion2b_s32b_b79k",
                clip_model_path='cache',
                device='cuda:0',
            )
        )
        print("Loading sdxl txt2img pipeline...")
        self.txt2img_pipe = DiffusionPipeline.from_pretrained(
            SDXL_MODEL_CACHE,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        print("Loading SDXL img2img pipeline...")
        # Build img2img from the txt2img pipeline's already-loaded modules
        # instead of loading a second copy of the weights.
        self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
            vae=self.txt2img_pipe.vae,
            text_encoder=self.txt2img_pipe.text_encoder,
            text_encoder_2=self.txt2img_pipe.text_encoder_2,
            tokenizer=self.txt2img_pipe.tokenizer,
            tokenizer_2=self.txt2img_pipe.tokenizer_2,
            unet=self.txt2img_pipe.unet,
            scheduler=self.txt2img_pipe.scheduler,
        )
        # Only img2img is moved explicitly. NOTE(review): because the modules
        # are shared, txt2img's modules presumably end up on CUDA as well.
        self.img2img_pipe.to("cuda")

    def remove_first_part(self, input_string: str) -> str:
        """Return *input_string* without its first comma-separated segment.

        Everything after the first comma is returned with surrounding
        whitespace stripped. If there is no comma, the whole string is
        returned stripped. Not currently called by ``predict``.
        """
        _, sep, tail = input_string.partition(",")
        return tail.strip() if sep else input_string.strip()

    def predict(
        self,
        image: Path = Input(description="Input image"),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model"""
        # Hardcoded params
        width, height = 1024, 1024
        scheduler = "K_EULER_ANCESTRAL"

        # Decode the input inside a context manager so the file handle is
        # released; convert() returns an independent in-memory image.
        with Image.open(str(image)) as src:
            image = src.convert("RGB")
        # Force SDXL's working resolution (aspect ratio is not preserved).
        image = image.resize((width, height))

        # Caption the resized image with CLIP.
        clip_txt = self.ci.interrogate(image)

        if seed is None:
            seed = int.from_bytes(os.urandom(4), "big")
        print(f"Using seed: {seed}")

        pipe = self.img2img_pipe
        pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
        generator = torch.Generator("cuda").manual_seed(seed)

        # NOTE(review): "portait" looks like a typo for "portrait"; left as-is
        # because changing the prompt changes the output for a given seed.
        full_prompt = "A cartoon portait picture, full art illustration, " + clip_txt
        print("Final prompt: " + full_prompt)

        output = pipe(
            prompt=full_prompt,
            negative_prompt="",
            image=image,
            # img2img strength: 1.0 would fully destroy the input image's information.
            strength=0.9,
            guidance_scale=7.5,
            generator=generator,
            num_inference_steps=40,
        )
        output_path = "/tmp/output.png"
        output.images[0].save(output_path)
        return Path(output_path)