Allow for batched models to work in trt pipelines #6

Open
wants to merge 1 commit into base: nihanth/batching

@@ -169,6 +169,7 @@ def __init__(
optimized_model_dir: Optional[str] = None,
version: Optional[str] = 'xl-1.0',
pipeline_type: PIPELINE_TYPE = PIPELINE_TYPE.SD_XL_BASE,
+max_batchsize: int = 1,
):
super().__init__()

@@ -199,22 +200,26 @@ def __init__(

stream = torch.cuda.current_stream().cuda_stream
clip_runner = CLIPRunner(framework_model_dir=self.optimized_model_dir, output_hidden_states=True,
-version=version, pipeline_type=pipeline_type, stream=stream)
-clip_obj = clip_runner.make_clip()
-self.base_clip_engine = clip_runner.load_engine(clip_obj, batch_size=1)
+version=version, pipeline_type=pipeline_type, stream=stream,
+max_batch_size=max_batchsize)
+self.clip_obj = clip_runner.make_clip()
+self.base_clip_engine = clip_runner.load_engine(self.clip_obj)

clip2_runner = CLIP2Runner(framework_model_dir=self.optimized_model_dir, output_hidden_states=True,
-version=version, pipeline_type=pipeline_type, stream=stream)
-clip2_obj = clip2_runner.make_clip_with_proj()
-self.base_clip2_engine = clip2_runner.load_engine(clip2_obj, batch_size=1)
+version=version, pipeline_type=pipeline_type, stream=stream,
+max_batch_size=max_batchsize)
+self.clip2_obj = clip2_runner.make_clip_with_proj()
+self.base_clip2_engine = clip2_runner.load_engine(self.clip2_obj)

self.unetxl_runner = UNETXLRunnerInfer(framework_model_dir=self.optimized_model_dir, version=version,
-scheduler=None, pipeline_type=pipeline_type, stream=stream)
+scheduler=None, pipeline_type=pipeline_type, stream=stream,
+max_batch_size=max_batchsize)
self.base_unetxl_engine = self.unetxl_runner.load_engine()

# The width/height for which the TRT modules are initialised
self.trt_width = None
self.trt_height = None
+self.trt_batch = None

self.vae_runner = ImageOnlyVaeRunner(self.vae)
self.vae_runner.setup_model()
@@ -223,24 +228,28 @@ def __init__(

# Complete warmups - should only be necessary to warmup with max img sizes
# TODO check this is the case
-clip_runner.warmup(self.base_clip_engine, clip_obj)
-clip2_runner.warmup(self.base_clip2_engine, clip2_obj)
+clip_runner.warmup(self.base_clip_engine, self.clip_obj)
+clip2_runner.warmup(self.base_clip2_engine, self.clip2_obj)
self.unetxl_runner.warmup(self.base_unetxl_engine, 1024, 1024)
-self.vae_runner.warmup(1024, 1024, batch_size=1)
+self.vae_runner.warmup(1024, 1024)
torch.cuda.synchronize() # Wait for warmup nonsense to finish.

# Reinitialise the TRT modules to operate on the given width/height, if needed.
# This operation is very expensive, but less expensive than completely recreating the pipeline. If you have a
# workload that's mostly images of the same size with the occasional outlier, then exploiting this functionality
# might be beneficial.
-def possibly_reinitialise_tensorrt(self, w, h):
-if self.trt_width == w and self.trt_height == h:
+def possibly_reinitialise_tensorrt(self, w, h, bs):
+if self.trt_width == w and self.trt_height == h and self.trt_batch == bs:
return

self.trt_height = h
self.trt_width = w
+self.trt_batch = bs

-self.base_unetxl_engine.allocate_buffers(shape_dict=self.unetxl_runner.get_shape_dict(h, w), device="cuda")
+self.base_clip_engine.allocate_buffers(shape_dict=self.clip_obj.get_shape_dict(bs, h, w), device=self.device)
+self.base_clip2_engine.allocate_buffers(shape_dict=self.clip2_obj.get_shape_dict(bs, h, w), device=self.device)
+self.base_unetxl_engine.allocate_buffers(shape_dict=self.unetxl_runner.get_shape_dict(h, w, bs),
+device=self.device)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
@@ -863,8 +872,6 @@ def __call__(
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""

-self.possibly_reinitialise_tensorrt(width, height)

callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)

@@ -918,6 +925,8 @@ def __call__(
else:
batch_size = prompt_embeds.shape[0]

+self.possibly_reinitialise_tensorrt(width, height, batch_size)

device = self._execution_device

# 3. Encode input prompt
@@ -1062,20 +1071,20 @@ def __call__(
progress_bar.update()

if output_type == "latent":
-image = latents
+images = latents
else:
-image = self.vae_runner.run(latents)
+images = self.vae_runner.run(latents)

# apply watermark if available
if self.watermark is not None:
-image = self.watermark.apply_watermark(image)
+images = self.watermark.apply_watermark(images)

-image = self.image_processor.postprocess(image, output_type=output_type)
+images = self.image_processor.postprocess(images, output_type=output_type)

# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
-return (image,)
+return (images,)

-return StableDiffusionXLPipelineOutput(images=image)
+return StableDiffusionXLPipelineOutput(images=images)
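For reviewers, a rough usage sketch of what the batched base pipeline looks like after this change. The class name `TRTStableDiffusionXLPipeline`, the engine directory, and the prompt strings are placeholders of mine (the real names are not visible in this diff), and the real constructor very likely takes more arguments than shown here; only `max_batchsize`, the constructor arguments visible above, and the fact that `batch_size` comes from the prompt list / `prompt_embeds.shape[0]` are taken from the PR itself:

```python
# Sketch only: class name and paths are placeholders, not the repo's actual API.
# Imports of the pipeline class and PIPELINE_TYPE are omitted; they depend on the repo layout,
# and constructor arguments elided in the diff are omitted here too.
pipe = TRTStableDiffusionXLPipeline(
    optimized_model_dir="engines/sdxl-base",   # placeholder path
    version='xl-1.0',
    pipeline_type=PIPELINE_TYPE.SD_XL_BASE,
    max_batchsize=4,                           # engines and buffers sized for up to 4 prompts
)

# batch_size is derived from the number of prompts; possibly_reinitialise_tensorrt
# only reallocates buffers when (width, height, batch_size) changes between calls.
images = pipe(
    prompt=["a corgi", "a tabby cat", "a red panda", "an axolotl"],
    width=1024,
    height=1024,
).images
```

The second file in the diff below appears to be the img2img/refiner variant of the same pipeline (the `ref_*` engines); it gets the same `max_batch_size` plumbing, plus a reshuffle so that the TRT buffers are sized from the input image or latents before the prompts are encoded.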
@@ -187,6 +187,7 @@ def __init__(
optimized_model_dir: Optional[str] = None,
version: Optional[str] = 'xl-1.0',
pipeline_type: PIPELINE_TYPE = PIPELINE_TYPE.SD_XL_BASE,
+max_batchsize: int = 1,
):
super().__init__()

@@ -216,22 +217,24 @@ def __init__(

stream = torch.cuda.current_stream().cuda_stream

-self.unetxl_runner = UNETXLRunnerInfer(framework_model_dir=self.optimized_model_dir, version=version, scheduler=None,
-pipeline_type=pipeline_type, stream=stream)
+self.unetxl_runner = UNETXLRunnerInfer(framework_model_dir=self.optimized_model_dir, version=version,
+scheduler=None, pipeline_type=pipeline_type, stream=stream,
+max_batch_size=max_batchsize)
self.ref_unetxl_engine = self.unetxl_runner.load_engine()

self.ref_clip_engine = None
if pipeline_type == PIPELINE_TYPE.SD_XL_BASE:
clip_runner = CLIPRunner(framework_model_dir=self.optimized_model_dir, output_hidden_states=True,
-version=version,
-pipeline_type=pipeline_type, stream=stream)
-clip_obj = clip_runner.make_clip()
-self.ref_clip_engine = clip_runner.load_engine(clip_obj, batch_size=1)
+version=version, pipeline_type=pipeline_type, stream=stream,
+max_batch_size=max_batchsize)
+self.clip_obj = clip_runner.make_clip()
+self.ref_clip_engine = clip_runner.load_engine(self.clip_obj)

-clip2_runner = CLIP2Runner(framework_model_dir=self.optimized_model_dir, output_hidden_states=True, version=version,
-pipeline_type=pipeline_type, stream=stream)
-clip2_obj = clip2_runner.make_clip_with_proj()
-self.ref_clip2_engine = clip2_runner.load_engine(clip2_obj, batch_size=1)
+clip2_runner = CLIP2Runner(framework_model_dir=self.optimized_model_dir, output_hidden_states=True,
+version=version, pipeline_type=pipeline_type, stream=stream,
+max_batch_size=max_batchsize)
+self.clip2_obj = clip2_runner.make_clip_with_proj()
+self.ref_clip2_engine = clip2_runner.load_engine(self.clip2_obj)

self.vae_runner = ImageOnlyVaeRunner(self.vae)
self.vae_runner.setup_model()
@@ -241,28 +244,35 @@
# The width/height for which the TRT modules are initialised
self.trt_width = None
self.trt_height = None
+self.trt_batch = None

self.unetxl_runner.warmup(self.ref_unetxl_engine, 1024, 1024)
-self.vae_runner.warmup(1024, 1024, batch_size=1)
+self.vae_runner.warmup(1024, 1024)

if self.ref_clip_engine is not None:
-clip_runner.warmup(self.ref_clip_engine, clip_obj)
-clip2_runner.warmup(self.ref_clip2_engine, clip2_obj)
+clip_runner.warmup(self.ref_clip_engine, self.clip_obj)
+clip2_runner.warmup(self.ref_clip2_engine, self.clip2_obj)

torch.cuda.synchronize() # Wait for warmup nonsense to finish.

# Reinitialise the TRT modules to operate on the given width/height, if needed.
# This operation is very expensive, but less expensive than completely recreating the pipeline. If you have a
# workload that's mostly images of the same size with the occasional outlier, then exploiting this functionality
# might be beneficial.
-def possibly_reinitialise_tensorrt(self, w, h):
-if self.trt_width == w and self.trt_height == h:
+def possibly_reinitialise_tensorrt(self, w, h, bs):
+if self.trt_width == w and self.trt_height == h and self.trt_batch == bs:
return

self.trt_height = h
self.trt_width = w
+self.trt_batch = bs

-self.ref_unetxl_engine.allocate_buffers(shape_dict=self.unetxl_runner.get_shape_dict(h, w), device="cuda")
+if self.ref_clip_engine is not None:
+self.ref_clip_engine.allocate_buffers(shape_dict=self.clip_obj.get_shape_dict(bs, h, w),
+device=self.device)
+self.ref_clip2_engine.allocate_buffers(shape_dict=self.clip2_obj.get_shape_dict(bs, h, w), device=self.device)
+self.ref_unetxl_engine.allocate_buffers(shape_dict=self.unetxl_runner.get_shape_dict(h, w, bs),
+device=self.device)

# TODO: this one is delightfully inconsistent.
# self.base_vae_runner = VAERunner(framework_model_dir=self.optimized_model_dir, version=version,
@@ -1102,6 +1112,23 @@ def __call__(
else:
batch_size = prompt_embeds.shape[0]

+# 4. Preprocess image
+image = self.image_processor.preprocess(image)
+
+# Deduce whether the input is a latent or an image. Both are accepted because _reasons_
+if image.shape[1] == 4:
+# It's a latent.
+height, width = image.shape[-2:]
+height = height * self.vae_scale_factor
+width = width * self.vae_scale_factor
+else:
+height = image.shape[2]
+width = image.shape[3]
+
+# Needs to be initialised with the actual *image* size. Note that the "image" parameter may actually
+# be latents, because everything must be as confusing as possible.
+self.possibly_reinitialise_tensorrt(width, height, batch_size)

device = self._execution_device

# 3. Encode input prompt
@@ -1129,23 +1156,6 @@ def __call__(
clip_skip=self.clip_skip,
)

-# 4. Preprocess image
-image = self.image_processor.preprocess(image)
-
-# Deduce whether the input is a latent or an image. Both are accepted because _reasons_
-if image.shape[1] == 4:
-# It's a latent.
-height, width = image.shape[-2:]
-height = height * self.vae_scale_factor
-width = width * self.vae_scale_factor
-else:
-height = image.shape[2]
-width = image.shape[3]
-
-# Needs to be initialised with the actual *image* size. Note that the "image" parameter may actually
-# be latents, because everything must be as confusing as possible.
-self.possibly_reinitialise_tensorrt(width, height)

stream = torch.cuda.current_stream().cuda_stream

# 5. Prepare timesteps
@@ -1285,21 +1295,21 @@ def denoising_value_valid(dnv):
progress_bar.update()

if not output_type == "latent":
-image = self.vae_runner.run(latents)
+images = self.vae_runner.run(latents)
else:
-image = latents
-return StableDiffusionXLPipelineOutput(images=image)
+images = latents
+return StableDiffusionXLPipelineOutput(images=images)

# apply watermark if available
if self.watermark is not None:
-image = self.watermark.apply_watermark(image)
+images = self.watermark.apply_watermark(images)

-image = self.image_processor.postprocess(image, output_type=output_type)
+images = self.image_processor.postprocess(images, output_type=output_type)

# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
-return (image,)
+return (images,)

-return StableDiffusionXLPipelineOutput(images=image)
+return StableDiffusionXLPipelineOutput(images=images)
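Since the moved size-deduction block is the crux of the second file's `__call__` change, here it is restated as a standalone sketch. The helper name is mine, and `vae_scale_factor` is the usual diffusers attribute (typically 8 for SDXL); the branch bodies mirror the diff:

```python
import torch

def deduce_pixel_size(image: torch.Tensor, vae_scale_factor: int = 8):
    """Accept either latents (N, 4, h, w) or a preprocessed image batch (N, 3, H, W)
    and return the pixel-space (height, width) the TRT buffers must be sized for."""
    if image.shape[1] == 4:
        # Latents: scale the spatial dims back up to pixel space.
        height, width = image.shape[-2:]
        return height * vae_scale_factor, width * vae_scale_factor
    # Already a pixel-space image batch.
    return image.shape[2], image.shape[3]

# e.g. latents of shape (2, 4, 128, 128) -> (1024, 1024); the pipeline then calls
# possibly_reinitialise_tensorrt(width, height, batch_size) so the CLIP/CLIP2 and
# UNet buffers are allocated for that batch before the prompts are encoded.
```

Moving this block ahead of prompt encoding (rather than after it, where it used to sit) presumably ensures the CLIP engines already have buffers allocated for the right batch size before they are first used.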