Skip to content

Commit

Permalink
Add controlnet to 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Feb 9, 2024
1 parent 01575a8 commit 1045116
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 46 deletions.
138 changes: 132 additions & 6 deletions apps/shark_studio/api/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,29 @@
import os
import PIL
import numpy as np
from apps.shark_studio.modules.pipeline import SharkPipelineBase
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
)
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
clean_device_info,
get_iree_target_triple,
)
from apps.shark_studio.web.utils.file_utils import (
safe_name,
get_resource_path,
get_checkpoints_path,
)
import cv2
from datetime import datetime
from PIL import Image
from gradio.components.image_editor import (
EditorValue,
)

# from turbine_models.custom_models.sd_inference import export_controlnet_model, ControlNetModel
import gc

class control_adapter:
def __init__(
Expand All @@ -20,7 +34,14 @@ def __init__(
self.model = None

def export_control_adapter_model(model_keyword):
return None
if model_keyword == "canny":
return export_controlnet_model(
ControlNetModel("lllyasviel/control_v11p_sd15_canny"),
"lllyasviel/control_v11p_sd15_canny",
1,
512,
512,
)

def export_xl_control_adapter_model(model_keyword):
return None
Expand All @@ -36,9 +57,16 @@ def __init__(
def export_controlnet_model(model_keyword):
return None

ireec_flags = [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))",
"--iree-flow-inline-constants-max-byte-length=1" # Stopgap, take out when not needed
]

control_adapter_map = {
"sd15": {
"runwayml/stable-diffusion-v1-5": {
"canny": {"initializer": control_adapter.export_control_adapter_model},
"openpose": {"initializer": control_adapter.export_control_adapter_model},
"scribble": {"initializer": control_adapter.export_control_adapter_model},
Expand All @@ -64,14 +92,112 @@ def __init__(
):
self.model = hf_model_id
self.device = device
self.compiled_model = None

def compile(self):
if self.compiled_model is not None:
return
if "canny" in self.model:
return
if "openpose" in self.model:
pass
print("compile not implemented for preprocessor.")
return

def run(self, inputs):
print("run not implemented for preprocessor.")
return inputs
if self.compiled_model is None:
self.compile()
if "canny" in self.model:
out = cv2.Canny(*inputs)
return out
if "openpose" in self.model:
self.compiled_model(*inputs)

def __call__(self, *inputs):
return self.run(inputs)


class SharkControlnetPipeline(SharkPipelineBase):
def __init__(
self,
# model_map: dict,
# static_kwargs: dict,
device: str,
# import_mlir: bool = True,
):
self.model_map = control_adapter_map
self.pipe_map = {}
# self.static_kwargs = static_kwargs
self.static_kwargs = {}
self.triple = get_iree_target_triple(device)
self.device, self.device_id = clean_device_info(device)
self.import_mlir = False
self.iree_module_dict = {}
self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp"))
if not os.path.exists(self.tmp_dir):
os.mkdir(self.tmp_dir)
self.tempfiles = {}
self.pipe_vmfb_path = ""
self.ireec_flags = ireec_flags

def get_compiled_map(self, model, init_kwargs={}):
self.pipe_map[model] = {}
if model in self.iree_module_dict:
return
elif model not in self.tempfiles:
# if model in self.static_kwargs[model]:
# init_kwargs = self.static_kwargs[model]
init_kwargs = {}
# for key in self.static_kwargs["pipe"]:
# if key not in init_kwargs:
# init_kwargs[key] = self.static_kwargs["pipe"][key]
self.import_torch_ir(model, init_kwargs)
self.get_compiled_map(model)
else:
# weights_path = self.get_io_params(model)

self.iree_module_dict[model] = get_iree_compiled_module(
self.tempfiles[model],
device=self.device,
frontend="torch",
mmap=True,
# external_weight_file=weights_path,
external_weight_file=None,
extra_args=self.ireec_flags,
write_to=os.path.join(self.pipe_vmfb_path, model + ".vmfb")
)

def import_torch_ir(self, model, kwargs):
# torch_ir = self.model_map[model]["initializer"](
# **self.safe_dict(kwargs), compile_to="torch"
# )
tmp_kwargs = {
"model_keyword": "canny"
}
torch_ir = self.model_map["sd15"][model]["initializer"](
**self.safe_dict(tmp_kwargs) #, compile_to="torch"
)

self.tempfiles[model] = os.path.join(
self.tmp_dir, f"{model}.torch.tempfile"
)

with open(self.tempfiles[model], "w+") as f:
f.write(torch_ir)
del torch_ir
gc.collect()
return

def get_precompiled(self, model):
vmfbs = []
for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path):
vmfbs.extend(filenames)
break
for file in vmfbs:
if model in file:
self.pipe_map[model]["vmfb_path"] = os.path.join(
self.pipe_vmfb_path, file
)
return


def cnet_preview(model, input_image):
Expand Down
86 changes: 70 additions & 16 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference import clip, unet, vae
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.api.controlnet import control_adapter_map, SharkControlnetPipeline
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.web.utils.file_utils import (
safe_name,
Expand Down Expand Up @@ -112,10 +112,10 @@ def __init__(
"unet": {
"hf_model_name": base_model_id,
"unet_model": unet.UnetModel(
hf_model_name=base_model_id, hf_auth_token=None
hf_model_name=base_model_id, hf_auth_token=None, is_controlled=False,
),
"batch_size": batch_size,
# "is_controlled": is_controlled,
"is_controlled": is_controlled,
# "num_loras": num_loras,
"height": height,
"width": width,
Expand All @@ -126,7 +126,6 @@ def __init__(
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=base_model_id,
custom_vae=custom_vae,
),
"batch_size": batch_size,
"height": height,
Expand All @@ -137,7 +136,6 @@ def __init__(
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=base_model_id,
custom_vae=custom_vae,
),
"batch_size": batch_size,
"height": height,
Expand All @@ -163,6 +161,7 @@ def __init__(
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
del static_kwargs
gc.collect()
self.controlnet = SharkControlnetPipeline(device)

def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
print(f"\n[LOG] Preparing pipeline...")
Expand Down Expand Up @@ -291,6 +290,7 @@ def produce_img_latents(
mask=None,
masked_image_latents=None,
return_all_latents=False,
controlnet_latents=None
):
# self.status = SD_STATE_IDLE
step_time_sum = 0
Expand All @@ -299,6 +299,7 @@ def produce_img_latents(
text_embeddings_numpy = text_embeddings.detach().numpy()
guidance_scale = torch.Tensor([guidance_scale]).to(self.dtype)
self.load_submodels(["unet"])
control_scale = torch.tensor(1.0, dtype=self.dtype)
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(self.dtype).detach().numpy()
Expand All @@ -319,15 +320,52 @@ def produce_img_latents(

# Profiling Unet.
# profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.run(
"unet",
[
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
],
)
if controlnet_latents is None:
noise_pred = self.run(
"unet",
[
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
],
)
else:
noise_pred = self.run(
"unet",
[
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
controlnet_latents[0],
controlnet_latents[1],
controlnet_latents[2],
controlnet_latents[3],
controlnet_latents[4],
controlnet_latents[5],
controlnet_latents[6],
controlnet_latents[7],
controlnet_latents[8],
controlnet_latents[9],
controlnet_latents[10],
controlnet_latents[11],
controlnet_latents[12],
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
],
)
# end_profiling(profile_device)

if cpu_scheduling:
Expand Down Expand Up @@ -388,6 +426,7 @@ def generate_images(
repeatable_seeds,
resample_type,
control_mode,
controlnet_models,
hints,
):
# TODO: Batched args
Expand Down Expand Up @@ -432,12 +471,24 @@ def generate_images(
strength=strength,
)

hints = [Image.load_file(x) for x in hints]
controlnet_latents = None
for (model, hint) in zip(controlnet_models, hints):
# if model not in self.controlnets:
# continue
self.controlnet.get_compiled_map("canny")
latent = self.controlnets[model].run(hint)
if controlnet_latents is None:
controlnet_latents = latent
break

latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
cpu_scheduling=True, # until we have schedulers through Turbine
controlnet_latents=controlnet_latents,
)

# Img latents -> PIL images
Expand Down Expand Up @@ -511,6 +562,7 @@ def shark_sd_fn(
is_controlled = False
control_mode = None
hints = []
controlnet_models = []
num_loras = 0
for i in embeddings:
num_loras += 1 if embeddings[i] else 0
Expand All @@ -525,16 +577,17 @@ def shark_sd_fn(
}
else:
adapters[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][
"hf_id": +["stabilityai/stable-diffusion-xl-1.0"][
model
],
"strength": controlnets["strength"][i],
}
if model is not None:
is_controlled = True
controlnet_models.append(model)
control_mode = controlnets["control_mode"]
for i in controlnets["hint"]:
hints.append[i]
hints.append(i)

submit_pipe_kwargs = {
"base_model_id": base_model_id,
Expand Down Expand Up @@ -567,6 +620,7 @@ def shark_sd_fn(
"repeatable_seeds": repeatable_seeds,
"resample_type": resample_type,
"control_mode": control_mode,
"controlnet_models": controlnet_models,
"hints": hints,
}
if (
Expand Down
2 changes: 1 addition & 1 deletion apps/shark_studio/modules/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from msvcrt import kbhit
# from msvcrt import kbhit
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
Expand Down
Loading

0 comments on commit 1045116

Please sign in to comment.