WIP: onnx support #1

Draft: wants to merge 1 commit into main

This draft adds an ONNX Runtime variant of the image-embedding Stable Diffusion pipeline (StableDiffusionImageEmbedOnnxPipeline) together with a script that exports the UNet and VAE of an existing checkpoint to ONNX.
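
For context, a minimal usage sketch of the new pipeline (the output directory ./sd-image-onnx and the input file are hypothetical placeholders; loading with provider="CPUExecutionProvider" mirrors the loadability check at the end of the conversion script):

import PIL.Image
from lambda_diffusers import StableDiffusionImageEmbedOnnxPipeline

# load a converted pipeline; "./sd-image-onnx" stands in for the directory
# written by scripts/convert_sd_image_checkpoint_to_onnx.py
pipe = StableDiffusionImageEmbedOnnxPipeline.from_pretrained(
    "./sd-image-onnx", provider="CPUExecutionProvider"
)

# generation is conditioned on an input image rather than a text prompt
init_image = PIL.Image.open("input.jpg").convert("RGB")
result = pipe(init_image, num_inference_steps=50)
result.images[0].save("variation.jpg")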
3 changes: 2 additions & 1 deletion lambda_diffusers/__init__.py
@@ -1 +1,2 @@
-from .pipelines import StableDiffusionImageEmbedPipeline
+from .pipelines import StableDiffusionImageEmbedPipeline
+from .pipelines import StableDiffusionImageEmbedOnnxPipeline
4 changes: 3 additions & 1 deletion lambda_diffusers/pipelines/__init__.py
@@ -1 +1,3 @@
-from .pipeline_stable_diffusion_im_embed import StableDiffusionImageEmbedPipeline
+from .pipeline_stable_diffusion_im_embed import StableDiffusionImageEmbedPipeline
+from .pipeline_stable_diffusion_im_embed_onnx import StableDiffusionImageEmbedOnnxPipeline
+

169 changes: 169 additions & 0 deletions lambda_diffusers/pipelines/pipeline_stable_diffusion_im_embed_onnx.py
@@ -0,0 +1,169 @@
import inspect
from typing import List, Optional, Union

import numpy as np

from transformers import CLIPFeatureExtractor, CLIPModel

from diffusers.onnx_utils import OnnxRuntimeModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

import torch
import PIL
import warnings


class StableDiffusionImageEmbedOnnxPipeline(DiffusionPipeline):
vae_decoder: OnnxRuntimeModel
unet: OnnxRuntimeModel
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
safety_checker: StableDiffusionSafetyChecker

def __init__(
self,
vae_decoder: OnnxRuntimeModel,
unet: OnnxRuntimeModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
scheduler = scheduler.set_format("np")
self.register_modules(
vae_decoder=vae_decoder,
unet=unet,
scheduler=scheduler,
)
self.feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")
self.image_encoder = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")

    def __call__(
self,
input_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
eta: Optional[float] = 0.0,
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)

# Set device as before (to be removed in 0.3.0)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)

if isinstance(input_image, PIL.Image.Image):
batch_size = 1
elif isinstance(input_image, list):
batch_size = len(input_image)
else:
            raise ValueError(f"`input_image` has to be of type `PIL.Image.Image` or `list` but is {type(input_image)}")

if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

if not isinstance(input_image, torch.FloatTensor):
input_image = self.feature_extractor(images=input_image, return_tensors="pt").to(self.device)

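        # encode the image with CLIP: take the pooled vision output, apply the visual
        # projection, then add a sequence axis so the embedding matches the
        # (batch, sequence, dim) layout the UNet cross-attention expects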
image_encoder_output = self.image_encoder.vision_model(input_image["pixel_values"])[1]
image_embeddings = self.image_encoder.visual_projection(image_encoder_output)
image_embeddings = image_embeddings.unsqueeze(1)

        # here `guidance_scale` is defined analogously to the guidance weight `w` of
        # equation (2) of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf .
        # `guidance_scale = 1` corresponds to doing no classifier-free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_embeddings = torch.zeros_like(image_embeddings)

            # For classifier-free guidance we need two forward passes per step.
            # Here we concatenate the unconditional and image embeddings into a single
            # batch to avoid doing two forward passes.
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])


image_embeddings = image_embeddings.cpu().detach().numpy()
# get the initial random noise unless the user supplied it
latents_shape = (batch_size, 4, height // 8, width // 8)
if latents is None:
latents = np.random.randn(*latents_shape).astype(np.float32)
elif latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")


# set timesteps
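        # some schedulers (e.g. PNDM) also accept an `offset` kwarg; when supported,
        # offset=1 shifts the schedule the way the reference Stable Diffusion pipelines do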
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
if accepts_offset:
extra_set_kwargs["offset"] = 1

self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents * self.scheduler.sigmas[0]

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta

for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[i]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=image_embeddings
)
noise_pred = noise_pred[0]

# perform guidance
if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_image = np.split(noise_pred, 2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_image - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0]

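        # denormalize from [-1, 1] to [0, 1] and move channels last (NCHW -> NHWC)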
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))

# run safety checker
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

if output_type == "pil":
image = self.numpy_to_pil(image)

if not return_dict:
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
161 changes: 161 additions & 0 deletions scripts/convert_sd_image_checkpoint_to_onnx.py
@@ -0,0 +1,161 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from pathlib import Path

import torch
from torch.onnx import export

from lambda_diffusers import StableDiffusionImageEmbedOnnxPipeline, StableDiffusionImageEmbedPipeline
from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version

is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")


def onnx_export(
model,
model_args: tuple,
output_path: Path,
ordered_input_names,
output_names,
dynamic_axes,
opset,
use_external_data_format=False,
):
output_path.parent.mkdir(parents=True, exist_ok=True)
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if is_torch_less_than_1_11:
export(
model,
model_args,
f=output_path.as_posix(),
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
enable_onnx_checker=True,
opset_version=opset,
)
else:
export(
model,
model_args,
f=output_path.as_posix(),
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
opset_version=opset,
)


@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int):
pipeline = StableDiffusionImageEmbedPipeline.from_pretrained(model_path)
output_path = Path(output_path)


# UNET
onnx_export(
pipeline.unet,
model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
output_path=output_path / "unet" / "model.onnx",
ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
output_names=["out_sample"], # has to be different from "sample" for correct tracing
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"timestep": {0: "batch"},
"encoder_hidden_states": {0: "batch", 1: "sequence"},
},
opset=opset,
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
)

# VAE ENCODER
vae_encoder = pipeline.vae
# need to get the raw tensor output (sample) from the encoder
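    # (torch.onnx.export traces `forward`, so re-route it through `encode` and draw a
    # sample from the returned latent distribution)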
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
onnx_export(
vae_encoder,
model_args=(torch.randn(1, 3, 512, 512), False),
output_path=output_path / "vae_encoder" / "model.onnx",
ordered_input_names=["sample", "return_dict"],
output_names=["latent_sample"],
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)

# VAE DECODER
vae_decoder = pipeline.vae
# forward only through the decoder part
vae_decoder.forward = vae_encoder.decode
onnx_export(
vae_decoder,
model_args=(torch.randn(1, 4, 64, 64), False),
output_path=output_path / "vae_decoder" / "model.onnx",
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)


onnx_pipeline = StableDiffusionImageEmbedOnnxPipeline(
vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
)

onnx_pipeline.save_pretrained(output_path)
print("ONNX pipeline saved to", output_path)

_ = StableDiffusionImageEmbedOnnxPipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
print("ONNX pipeline is loadable")


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
)

parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")

parser.add_argument(
"--opset",
default=14,
        type=int,
help="The version of the ONNX operator set to use.",
)

args = parser.parse_args()

convert_models(args.model_path, args.output_path, args.opset)
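
For reference, a hypothetical invocation (both paths are placeholders):

python scripts/convert_sd_image_checkpoint_to_onnx.py \
    --model_path ./sd-image-variations \
    --output_path ./sd-image-onnx \
    --opset 14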