import io
import os
import warnings
from PIL import Image
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
"""
https://github.com/Stability-AI/REST-API/issues/21
- width: 512
height: 512
- width: 768
height: 512
- width: 512
height: 768
- width: 640
height: 512
- width: 512
height: 640
- width: 896
height: 512
- width: 512
height: 896
"""


def clamp(num, min_value, max_value):
    return max(min(num, max_value), min_value)


def _preprocess_image(img):
    """init_image: image dimensions must be multiples of 64"""
    # NOTE: 512, 640, 768, 896 = 128 * [4, 5, 6, 7]
    if img is None:
        return img
    width, height = img.size
    # scale so the shorter side is exactly 512
    if width <= height:
        w = 512
        h = int((height / width) * w)
    else:
        h = 512
        w = int((width / height) * h)
    if w != width or h != height:
        img = img.resize((w, h))
        width, height = img.size
    # center-crop each side down to the nearest multiple of 128 in [512, 896]
    new_width, new_height = clamp(int(width / 128), 4, 7) * 128, clamp(int(height / 128), 4, 7) * 128
    if new_width != width or new_height != height:
        left, top = int((width - new_width) / 2), int((height - new_height) / 2)
        img = img.crop((left, top, left + new_width, top + new_height))
        print(f'Image is resized: {img.size}')
    return img
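
# A quick worked example of the preprocessing above (the input size is
# illustrative, not from the original code): a 1024x768 image is first scaled
# so its short side is 512 (-> 682x512), then center-cropped to the nearest
# 128-multiples in [512, 896] (-> 640x512), landing on one of the supported
# (width, height) pairs listed in the module docstring.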


def bot(user_message, history, refine=False, prompt_strength=0.6,
        translate=False, llm=None):
    from utils import parse_message, format_to_message
    msg_dict = parse_message(user_message)
    # Stability AI generates better results for English prompts
    translate = False if llm is None else translate
    if translate:
        msg_dict["text"] = llm.predict(f'Translate the following sentence into English: {msg_dict["text"]}')
    init_image = None
    if "images" in msg_dict and len(msg_dict["images"]) >= 1:
        init_image = msg_dict["images"][-1]
    # in refine mode, fall back to the last image found in earlier bot responses
    if init_image is None and refine:
        for _, _bot_msg in history[:-1][::-1]:
            _msg_dict = parse_message(_bot_msg)
            if "images" in _msg_dict and len(_msg_dict["images"]) >= 1:
                init_image = _msg_dict["images"][-1]
                break
    if init_image is not None:
        print(f'Refine from {init_image}')
    img = generate(msg_dict["text"],
                   init_image=_preprocess_image(Image.open(init_image)) if init_image is not None else None,
                   start_schedule=prompt_strength,
                   )
    if img is not None:
        import tempfile
        fname = tempfile.NamedTemporaryFile(prefix='gradio/stability_ai-', suffix='.png').name
        img.save(fname)
        text = f'<small>Translated prompt: *{msg_dict["text"]}*</small>' if translate else ""
        bot_message = format_to_message(dict(text=text, images=[fname]))
    else:
        bot_message = "Sorry, stability.ai failed to generate an image."
    return bot_message


def generate(prompt, init_image=None, mask_image=None, start_schedule=0.6, width=512, height=512,
             engine="stable-diffusion-xl-beta-v2-2-2"):
    """
    Parameters:
        init_image: PIL Image
        start_schedule: the strength of the prompt in relation to the initial image.
            0.0 -- initial image, 1.0 -- prompt
    Returns:
        img: PIL Image or None
    """
    # Set up our connection to the API.
    stability_api = client.StabilityInference(
        key=os.environ['STABILITY_API_KEY'],  # API key reference.
        verbose=True,  # Print debug messages.
        engine=engine,  # Set the engine to use for generation.
        # Available engines: stable-diffusion-v1 stable-diffusion-v1-5 stable-diffusion-512-v2-0 stable-diffusion-768-v2-0
        # stable-diffusion-512-v2-1 stable-diffusion-768-v2-1 stable-diffusion-xl-beta-v2-2-2 stable-inpainting-v1-0 stable-inpainting-512-v2-0
    )
    init_image = _preprocess_image(init_image)
    mask_image = _preprocess_image(mask_image)
    kwargs = dict(init_image=init_image, mask_image=mask_image, start_schedule=start_schedule)
    # Set up our initial generation parameters.
    answers = stability_api.generate(
        prompt=prompt,
        **kwargs,
        # init_image=img,      # Assign a previously generated img as the initial image for transformation.
        # start_schedule=1.0,  # Set the strength of the prompt in relation to the initial image.
        # seed=123467458,      # If attempting to transform an image that was previously generated with the API,
        #                      # initial images benefit from having their own distinct seed rather than using the seed of the original image generation.
        steps=30,  # Number of inference steps performed on image generation. Defaults to 30.
        cfg_scale=8.0,  # Influences how strongly your generation is guided to match your prompt.
        # Setting this value higher increases the strength with which it tries to match your prompt.
        # Defaults to 7.0 if not specified.
        width=width,  # Generation width, defaults to 512 if not included.
        height=height,  # Generation height, defaults to 512 if not included.
        sampler=generation.SAMPLER_K_DPMPP_2M  # Choose which sampler we want to denoise our generation with.
        # Defaults to k_dpmpp_2m if not specified. CLIP Guidance only supports ancestral samplers.
        # (Available samplers: ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_dpmpp_2s_ancestral, k_lms, k_dpmpp_2m, k_dpmpp_sde)
    )
    # Warn on the console if the adult content classifier is tripped;
    # otherwise collect the generated image (img stays None if nothing was returned).
    img = None
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                img = Image.open(io.BytesIO(artifact.binary))
    return img
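

# Minimal usage sketch (assumes STABILITY_API_KEY is set in the environment;
# the prompt and output filename below are illustrative, not part of the
# original module):
if __name__ == "__main__":
    image = generate("a lighthouse on a rocky cliff at sunset")
    if image is None:
        print("Generation failed or was filtered.")
    else:
        image.save("stability_ai_demo.png")
        print(f"Saved {image.size[0]}x{image.size[1]} image to stability_ai_demo.png")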