diff --git a/src/diffusers/pipelines/stable_diffusion_xl_trt/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl_trt/pipeline_stable_diffusion_xl.py
index 6a516b892ee5..16b009db48f3 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl_trt/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl_trt/pipeline_stable_diffusion_xl.py
@@ -169,6 +169,7 @@ def __init__(
         optimized_model_dir: Optional[str] = None,
         version: Optional[str] = 'xl-1.0',
         pipeline_type: PIPELINE_TYPE = PIPELINE_TYPE.SD_XL_BASE,
+        max_batchsize: int = 1,
     ):
         super().__init__()
 
@@ -199,22 +200,26 @@ def __init__(
         stream = torch.cuda.current_stream().cuda_stream
 
         clip_runner = CLIPRunner(framework_model_dir=self.optimized_model_dir, output_hidden_states=True,
-                                 version=version, pipeline_type=pipeline_type, stream=stream)
-        clip_obj = clip_runner.make_clip()
-        self.base_clip_engine = clip_runner.load_engine(clip_obj, batch_size=1)
+                                 version=version, pipeline_type=pipeline_type, stream=stream,
+                                 max_batch_size=max_batchsize)
+        self.clip_obj = clip_runner.make_clip()
+        self.base_clip_engine = clip_runner.load_engine(self.clip_obj)
 
         clip2_runner = CLIP2Runner(framework_model_dir=self.optimized_model_dir, output_hidden_states=True,
-                                   version=version, pipeline_type=pipeline_type, stream=stream)
-        clip2_obj = clip2_runner.make_clip_with_proj()
-        self.base_clip2_engine = clip2_runner.load_engine(clip2_obj, batch_size=1)
+                                   version=version, pipeline_type=pipeline_type, stream=stream,
+                                   max_batch_size=max_batchsize)
+        self.clip2_obj = clip2_runner.make_clip_with_proj()
+        self.base_clip2_engine = clip2_runner.load_engine(self.clip2_obj)
 
         self.unetxl_runner = UNETXLRunnerInfer(framework_model_dir=self.optimized_model_dir, version=version,
-                                               scheduler=None, pipeline_type=pipeline_type, stream=stream)
+                                               scheduler=None, pipeline_type=pipeline_type, stream=stream,
+                                               max_batch_size=max_batchsize)
         self.base_unetxl_engine = self.unetxl_runner.load_engine()
 
         # The width/height for which the TRT modules are initialised
         self.trt_width = None
         self.trt_height = None
+        self.trt_batch = None
 
         self.vae_runner = ImageOnlyVaeRunner(self.vae)
         self.vae_runner.setup_model()
@@ -223,24 +228,28 @@ def __init__(
 
         # Complete warmups - should only be necessary to warmup with max img sizes
         # TODO check this is the case
-        clip_runner.warmup(self.base_clip_engine, clip_obj)
-        clip2_runner.warmup(self.base_clip2_engine, clip2_obj)
+        clip_runner.warmup(self.base_clip_engine, self.clip_obj)
+        clip2_runner.warmup(self.base_clip2_engine, self.clip2_obj)
         self.unetxl_runner.warmup(self.base_unetxl_engine, 1024, 1024)
-        self.vae_runner.warmup(1024, 1024, batch_size=1)
+        self.vae_runner.warmup(1024, 1024)
 
         torch.cuda.synchronize()  # Wait for warmup nonsense to finish.
 
     # Reinitialise the TRT modules to operate on the given width/height, if needed.
     # This operation is very expensive, but less expensive than completely recreating the pipeline. If you have a
     # workload that's mostly images of the same size with the occasional outlier, then exploiting this functionality
     # might be beneficial.
-    def possibly_reinitialise_tensorrt(self, w, h):
-        if self.trt_width == w and self.trt_height == h:
+    def possibly_reinitialise_tensorrt(self, w, h, bs):
+        if self.trt_width == w and self.trt_height == h and self.trt_batch == bs:
             return
         self.trt_height = h
         self.trt_width = w
+        self.trt_batch = bs
 
-        self.base_unetxl_engine.allocate_buffers(shape_dict=self.unetxl_runner.get_shape_dict(h, w), device="cuda")
+        self.base_clip_engine.allocate_buffers(shape_dict=self.clip_obj.get_shape_dict(bs, h, w), device=self.device)
+        self.base_clip2_engine.allocate_buffers(shape_dict=self.clip2_obj.get_shape_dict(bs, h, w), device=self.device)
+        self.base_unetxl_engine.allocate_buffers(shape_dict=self.unetxl_runner.get_shape_dict(h, w, bs),
+                                                 device=self.device)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -863,8 +872,6 @@ def __call__(
 
                 `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
-        self.possibly_reinitialise_tensorrt(width, height)
-
         callback = kwargs.pop("callback", None)
         callback_steps = kwargs.pop("callback_steps", None)
 
@@ -918,6 +925,8 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        self.possibly_reinitialise_tensorrt(width, height, batch_size)
+
         device = self._execution_device
 
         # 3. Encode input prompt
@@ -1062,20 +1071,20 @@ def __call__(
                     progress_bar.update()
 
         if output_type == "latent":
-            image = latents
+            images = latents
         else:
-            image = self.vae_runner.run(latents)
+            images = self.vae_runner.run(latents)
 
         # apply watermark if available
         if self.watermark is not None:
-            image = self.watermark.apply_watermark(image)
+            images = self.watermark.apply_watermark(images)
 
-        image = self.image_processor.postprocess(image, output_type=output_type)
+        images = self.image_processor.postprocess(images, output_type=output_type)
 
         # Offload all models
         self.maybe_free_model_hooks()
 
         if not return_dict:
-            return (image,)
+            return (images,)
 
-        return StableDiffusionXLPipelineOutput(images=image)
+        return StableDiffusionXLPipelineOutput(images=images)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl_trt/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl_trt/pipeline_stable_diffusion_xl_img2img.py
index 9db89e0f9123..83d9dee0fc8d 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl_trt/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl_trt/pipeline_stable_diffusion_xl_img2img.py
@@ -187,6 +187,7 @@ def __init__(
         optimized_model_dir: Optional[str] = None,
         version: Optional[str] = 'xl-1.0',
         pipeline_type: PIPELINE_TYPE = PIPELINE_TYPE.SD_XL_BASE,
+        max_batchsize: int = 1,
     ):
         super().__init__()
 
@@ -216,22 +217,24 @@ def __init__(
         stream = torch.cuda.current_stream().cuda_stream
 
-        self.unetxl_runner = UNETXLRunnerInfer(framework_model_dir=self.optimized_model_dir, version=version, scheduler=None,
-                                               pipeline_type=pipeline_type, stream=stream)
+        self.unetxl_runner = UNETXLRunnerInfer(framework_model_dir=self.optimized_model_dir, version=version,
+                                               scheduler=None, pipeline_type=pipeline_type, stream=stream,
+                                               max_batch_size=max_batchsize)
         self.ref_unetxl_engine = self.unetxl_runner.load_engine()
 
         self.ref_clip_engine = None
         if pipeline_type == PIPELINE_TYPE.SD_XL_BASE:
             clip_runner = CLIPRunner(framework_model_dir=self.optimized_model_dir, output_hidden_states=True,
-                                     version=version,
-                                     pipeline_type=pipeline_type, stream=stream)
-            clip_obj = clip_runner.make_clip()
-            self.ref_clip_engine = clip_runner.load_engine(clip_obj, batch_size=1)
+                                     version=version, pipeline_type=pipeline_type, stream=stream,
+                                     max_batch_size=max_batchsize)
+            self.clip_obj = clip_runner.make_clip()
+            self.ref_clip_engine = clip_runner.load_engine(self.clip_obj)
 
-        clip2_runner = CLIP2Runner(framework_model_dir=self.optimized_model_dir, output_hidden_states=True, version=version,
-                                   pipeline_type=pipeline_type, stream=stream)
-        clip2_obj = clip2_runner.make_clip_with_proj()
-        self.ref_clip2_engine = clip2_runner.load_engine(clip2_obj, batch_size=1)
+        clip2_runner = CLIP2Runner(framework_model_dir=self.optimized_model_dir, output_hidden_states=True,
+                                   version=version, pipeline_type=pipeline_type, stream=stream,
+                                   max_batch_size=max_batchsize)
+        self.clip2_obj = clip2_runner.make_clip_with_proj()
+        self.ref_clip2_engine = clip2_runner.load_engine(self.clip2_obj)
 
         self.vae_runner = ImageOnlyVaeRunner(self.vae)
         self.vae_runner.setup_model()
 
@@ -241,13 +244,14 @@ def __init__(
 
         # The width/height for which the TRT modules are initialised
         self.trt_width = None
         self.trt_height = None
+        self.trt_batch = None
 
         self.unetxl_runner.warmup(self.ref_unetxl_engine, 1024, 1024)
-        self.vae_runner.warmup(1024, 1024, batch_size=1)
+        self.vae_runner.warmup(1024, 1024)
 
         if self.ref_clip_engine is not None:
-            clip_runner.warmup(self.ref_clip_engine, clip_obj)
-            clip2_runner.warmup(self.ref_clip2_engine, clip2_obj)
+            clip_runner.warmup(self.ref_clip_engine, self.clip_obj)
+            clip2_runner.warmup(self.ref_clip2_engine, self.clip2_obj)
 
         torch.cuda.synchronize()  # Wait for warmup nonsense to finish.
@@ -255,14 +259,20 @@ def __init__(
     # This operation is very expensive, but less expensive than completely recreating the pipeline. If you have a
     # workload that's mostly images of the same size with the occasional outlier, then exploiting this functionality
     # might be beneficial.
-    def possibly_reinitialise_tensorrt(self, w, h):
-        if self.trt_width == w and self.trt_height == h:
+    def possibly_reinitialise_tensorrt(self, w, h, bs):
+        if self.trt_width == w and self.trt_height == h and self.trt_batch == bs:
             return
         self.trt_height = h
         self.trt_width = w
+        self.trt_batch = bs
 
-        self.ref_unetxl_engine.allocate_buffers(shape_dict=self.unetxl_runner.get_shape_dict(h, w), device="cuda")
+        if self.ref_clip_engine is not None:
+            self.ref_clip_engine.allocate_buffers(shape_dict=self.clip_obj.get_shape_dict(bs, h, w),
+                                                  device=self.device)
+        self.ref_clip2_engine.allocate_buffers(shape_dict=self.clip2_obj.get_shape_dict(bs, h, w), device=self.device)
+        self.ref_unetxl_engine.allocate_buffers(shape_dict=self.unetxl_runner.get_shape_dict(h, w, bs),
+                                                device=self.device)
 
     # TODO: this one is delightfully inconsistent.
     # self.base_vae_runner = VAERunner(framework_model_dir=self.optimized_model_dir, version=version,
 
@@ -1102,6 +1112,23 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        # 4. Preprocess image
+        image = self.image_processor.preprocess(image)
+
+        # Deduce whether the input is a latent or an image. Both are accepted because _reasons_
+        if image.shape[1] == 4:
+            # It's a latent.
+            height, width = image.shape[-2:]
+            height = height * self.vae_scale_factor
+            width = width * self.vae_scale_factor
+        else:
+            height = image.shape[2]
+            width = image.shape[3]
+
+        # Needs to be initialised with the actual *image* size. Note that the "image" parameter may actually
+        # be latents, because everything must be as confusing as possible.
+        self.possibly_reinitialise_tensorrt(width, height, batch_size)
+
         device = self._execution_device
 
         # 3. Encode input prompt
@@ -1129,23 +1156,6 @@ def __call__(
             clip_skip=self.clip_skip,
         )
 
-        # 4. Preprocess image
-        image = self.image_processor.preprocess(image)
-
-        # Deduce whether the input is a latent or an image. Both are accepted because _reasons_
-        if image.shape[1] == 4:
-            # It's a latent.
-            height, width = image.shape[-2:]
-            height = height * self.vae_scale_factor
-            width = width * self.vae_scale_factor
-        else:
-            height = image.shape[2]
-            width = image.shape[3]
-
-        # Needs to be initialised with the actual *image* size. Note that the "image" parameter may actually
-        # be latents, because everything must be as confusing as possible.
-        self.possibly_reinitialise_tensorrt(width, height)
-
         stream = torch.cuda.current_stream().cuda_stream
 
         # 5. Prepare timesteps
@@ -1285,21 +1295,21 @@ def denoising_value_valid(dnv):
                     progress_bar.update()
 
         if not output_type == "latent":
-            image = self.vae_runner.run(latents)
+            images = self.vae_runner.run(latents)
         else:
-            image = latents
-            return StableDiffusionXLPipelineOutput(images=image)
+            images = latents
+            return StableDiffusionXLPipelineOutput(images=images)
 
         # apply watermark if available
         if self.watermark is not None:
-            image = self.watermark.apply_watermark(image)
+            images = self.watermark.apply_watermark(images)
 
-        image = self.image_processor.postprocess(image, output_type=output_type)
+        images = self.image_processor.postprocess(images, output_type=output_type)
 
         # Offload all models
         self.maybe_free_model_hooks()
 
         if not return_dict:
-            return (image,)
+            return (images,)
 
-        return StableDiffusionXLPipelineOutput(images=image)
+        return StableDiffusionXLPipelineOutput(images=images)
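
Reviewer note, not part of the patch: with this change the TRT engines are built once for batches up to `max_batchsize`, and `possibly_reinitialise_tensorrt` reallocates device buffers only when the (width, height, batch) triple actually changes, returning early otherwise. A minimal usage sketch of that caching behaviour follows; the class name `StableDiffusionXLTrtPipeline` and the `from_pretrained` wiring are assumptions for illustration, not confirmed by this diff.

# Hypothetical usage sketch -- pipeline class name and from_pretrained() kwargs are assumed.
import torch
from diffusers import StableDiffusionXLTrtPipeline  # assumed export

pipe = StableDiffusionXLTrtPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    optimized_model_dir="./trt_engines",  # real __init__ kwarg from this patch
    max_batchsize=4,                      # engines sized for batches up to 4
    torch_dtype=torch.float16,
).to("cuda")

# First call at (1024, 1024, batch=4): buffers are allocated for this shape.
pipe(prompt=["a lighthouse at dusk"] * 4, width=1024, height=1024)

# Same width/height/batch: possibly_reinitialise_tensorrt() returns early, no reallocation.
pipe(prompt=["a harbour at dawn"] * 4, width=1024, height=1024)

# Different batch size: one expensive buffer reallocation, then cached again.
pipe(prompt=["a storm at sea"] * 2, width=1024, height=1024)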
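
Also worth calling out for reviewers: the img2img pipeline now derives width/height from the `image` argument before prompt encoding, because `possibly_reinitialise_tensorrt` must see the pixel dimensions and the batch size together. The deduction keys on the channel dimension (a VAE latent has 4 channels, a preprocessed RGB image 3) and scales latent spatial dims back up by `vae_scale_factor`, which is 8 for the standard SDXL VAE. A self-contained sketch of the same arithmetic, with `deduce_pixel_size` as a hypothetical helper name:

# Standalone sketch of the latent-vs-image size deduction used in this patch.
import torch

VAE_SCALE_FACTOR = 8  # usual SDXL value: 2 ** (len(vae.config.block_out_channels) - 1)

def deduce_pixel_size(image: torch.Tensor) -> tuple:
    """Return (width, height) in pixel space for an image or latent batch (NCHW)."""
    if image.shape[1] == 4:
        # 4 channels -> already a VAE latent; scale spatial dims back to pixels.
        height, width = image.shape[-2:]
        return width * VAE_SCALE_FACTOR, height * VAE_SCALE_FACTOR
    # 3 channels -> preprocessed image tensor, already in pixel space.
    return image.shape[3], image.shape[2]

print(deduce_pixel_size(torch.zeros(2, 3, 1024, 1024)))  # (1024, 1024) -- image input
print(deduce_pixel_size(torch.zeros(2, 4, 128, 128)))    # (1024, 1024) -- latent input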