Readme fix, add workflow for Stable Diffusion 3.0 (Save VAE + Create) #30

Open · wants to merge 4 commits into base: master
README.md: 24 additions & 32 deletions
@@ -11,7 +11,7 @@ Supports:
- SDXL
- SDXL Turbo
- Stable Video Diffusion
- Stable Video Diffusion-XT

Requirements:

@@ -30,7 +30,7 @@ to easily install them to your ComfyUI instance.

You can also manually install them by git cloning the repo to your ComfyUI/custom_nodes folder and installing the requirements like:

-```
+```shell
cd custom_nodes
git clone https://github.com/comfyanonymous/ComfyUI_TensorRT
cd ComfyUI_TensorRT
@@ -68,19 +68,19 @@ These .json files can be loaded in ComfyUI.

### Building A TensorRT Engine From a Checkpoint

1. Add a Load Checkpoint Node
2. Add either a Static Model TensorRT Conversion node or a Dynamic
Model TensorRT Conversion node to ComfyUI
3. ![](readme_images/image3.png)
4. Connect the Load Checkpoint Model output to the TensorRT Conversion
Node Model input.
5. ![](readme_images/image5.png)
6. ![](readme_images/image2.png)
7. To help identify the converted TensorRT model, provide a meaningful
filename prefix, add this filename after “tensorrt/”
8. ![](readme_images/image9.png)

9. Click on Queue Prompt to start building the TensorRT Engines
10. ![](readme_images/image7.png)

![](readme_images/image11.png)
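As a concrete illustration of step 7 (the name is hypothetical, not taken from this PR): entering `tensorrt/sd35_large` as the filename prefix should save the finished engine into a `tensorrt/` subfolder of the ComfyUI output directory, with `sd35_large` as the recognizable part of the engine filename.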
@@ -112,33 +112,25 @@ TensorRT Engines are loaded using the TensorRT Loader node.
ComfyUI TensorRT engines are not yet compatible with ControlNets or
LoRAs. Compatibility will be enabled in a future update.

1. Add a TensorRT Loader node
2. Note, if a TensorRT Engine has been created during a ComfyUI
session, it will not show up in the TensorRT Loader until the
ComfyUI interface has been refreshed (F5 to refresh browser).
3. ![](readme_images/image6.png)
4. Select a TensorRT Engine from the unet_name dropdown
5. Dynamic Engines will use a filename format of:

    1. dyn-b-min-max-opt-h-min-max-opt-w-min-max-opt
    2. dyn=dynamic, b=batch size, h=height, w=width

6. Static Engines will use a filename format of:

    1. stat-b-opt-h-opt-w-opt
    2. stat=static, b=batch size, h=height, w=width

7. ![](readme_images/image8.png)
8. The model_type must match the model type of the TensorRT engine.
9. ![](readme_images/image10.png)
10. The CLIP and VAE for the workflow will need to be utilized from the
original model checkpoint; the MODEL output from the TensorRT Loader
will be connected to the Sampler.
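To make the two naming schemes above concrete, here are hypothetical engine names; every min/max/opt value below is illustrative rather than a default:

```python
# Dynamic engine: batch 1-4 (opt 2), height 512-1024 (opt 768), width 512-1024 (opt 768),
# following dyn-b-min-max-opt-h-min-max-opt-w-min-max-opt:
dynamic_engine = "dyn-b-1-4-2-h-512-1024-768-w-512-1024-768"

# Static engine: fixed batch 1, height 768, width 768, following stat-b-opt-h-opt-w-opt:
static_engine = "stat-b-1-h-768-w-768"
```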
pyproject.toml: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
[project]
name = "comfyui_tensorrt"
description = "TensorRT Node for ComfyUI\nThis node enables the best performance on NVIDIA RTX™ Graphics Cards (GPUs) for Stable Diffusion by leveraging NVIDIA TensorRT."
version = "0.1.1"
version = "0.1.2"
license = "LICENSE"
dependencies = [
"tensorrt>=10.0.1",
tensorrt_convert.py: 28 additions & 16 deletions
@@ -24,6 +24,7 @@
{".engine"},
)


class TQDMProgressMonitor(trt.IProgressMonitor):
def __init__(self):
trt.IProgressMonitor.__init__(self)
@@ -93,14 +94,18 @@ def step_complete(self, phase_name, step):
except KeyboardInterrupt:
# There is no need to propagate this exception to TensorRT. We can simply cancel the build.
return False


class TRT_MODEL_CONVERSION_BASE:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.temp_dir = folder_paths.get_temp_directory()
self.timing_cache_path = os.path.normpath(
-    os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"))
+    os.path.join(
+        os.path.join(
+            os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"
+        )
+    )
)

RETURN_TYPES = ()
@@ -150,24 +155,25 @@ def _convert(
is_static: bool,
):
output_onnx = os.path.normpath(
-    os.path.join(
-        os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx"
-    )
+    os.path.join(self.temp_dir, str(time.time()), "model.onnx")
)

comfy.model_management.unload_all_models()
comfy.model_management.load_models_gpu([model], force_patch_weights=True)
unet = model.model.diffusion_model

context_dim = model.model.model_config.unet_config.get("context_dim", None)
context_len = 77
context_len_min = context_len

-if context_dim is None: #SD3
-    context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None)
+if context_dim is None:  # SD3
+    context_embedder_config = model.model.model_config.unet_config.get(
+        "context_embedder_config", None
+    )
     if context_embedder_config is not None:
-        context_dim = context_embedder_config.get("params", {}).get("in_features", None)
-        context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
+        context_dim = context_embedder_config.get("params", {}).get(
+            "in_features", None
+        )
+        context_len = 154  # NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77

if context_dim is not None:
input_names = ["x", "timesteps", "context"]
@@ -179,7 +185,7 @@ def _convert(
"context": {0: "batch", 1: "num_embeds"},
}

-transformer_options = model.model_options['transformer_options'].copy()
+transformer_options = model.model_options["transformer_options"].copy()
if model.model.model_config.unet_config.get(
"use_temporal_resblock", False
): # SVD
@@ -205,7 +211,13 @@ def forward(self, x, timesteps, context, y):
unet = svd_unet
context_len_min = context_len = 1
else:

 class UNET(torch.nn.Module):
+    def __init__(self, unet, opts):
+        super().__init__()
+        self.unet = unet
+        self.transformer_options = opts
+
def forward(self, x, timesteps, context, y=None):
return self.unet(
x,
@@ -214,10 +226,8 @@ def forward(self, x, timesteps, context, y=None):
y,
transformer_options=self.transformer_options,
)
-_unet = UNET()
-_unet.unet = unet
-_unet.transformer_options = transformer_options
-unet = _unet
+unet = UNET(unet, transformer_options)

input_channels = model.model.model_config.unet_config.get("in_channels")

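The wrapper refactor above replaces post-hoc attribute assignment (`_unet = UNET()` followed by setting fields) with constructor injection. A generic sketch of the pattern, independent of this PR and with illustrative names:

```python
import torch

class Wrapped(torch.nn.Module):
    """Forwards to an inner module with options fixed at construction time."""

    def __init__(self, inner: torch.nn.Module, opts: dict):
        super().__init__()  # must run before any submodule is assigned
        self.inner = inner  # registered as a child module by nn.Module.__setattr__
        self.opts = opts    # plain attribute, not a parameter

    def forward(self, x):
        return self.inner(x)
```

Constructor injection keeps the module's dependencies explicit and self-contained instead of relying on callers to attach attributes to an empty instance.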
@@ -304,7 +314,9 @@ def forward(self, x, timesteps, context, y=None):
profile.set_shape(input_names[k], min_shape, opt_shape, max_shape)

# Encode shapes to filename
-encode = lambda a: ".".join(map(lambda x: str(x), a))
+def encode(a):
+    return ".".join(map(str, a))

prefix_encode += "{}#{}#{}#{};".format(
input_names[k], encode(min_shape), encode(opt_shape), encode(max_shape)
)
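As a quick check of the rewritten helper (the shape tuple is a hypothetical example), the `def` form behaves exactly like the lambda it replaces:

```python
def encode(a):
    return ".".join(map(str, a))

# e.g. a min/opt/max shape such as (batch, context_len, context_dim):
assert encode((2, 77, 2048)) == "2.77.2048"
```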
tensorrt_loader.py: 3 additions & 3 deletions
@@ -41,7 +41,6 @@ def __init__(self, engine_path):
self.engine = runtime.deserialize_cuda_engine(f.read())
self.context = self.engine.create_execution_context()
self.dtype = torch.float16
-self.stream = torch.cuda.Stream()

def set_bindings_shape(self, inputs, split_batch):
for k in inputs:
@@ -91,12 +90,13 @@ def __call__(self, x, timesteps, context, y=None, control=None, transformer_opti
dtype=trt_datatype_to_torch(self.engine.get_tensor_dtype(output_binding_name)))
model_inputs_converted[output_binding_name] = out

+stream = torch.cuda.default_stream(x.device)
for i in range(curr_split_batch):
for k in model_inputs_converted:
x = model_inputs_converted[k]
self.context.set_tensor_address(k, x[(x.shape[0] // curr_split_batch) * i:].data_ptr())
-self.context.execute_async_v3(stream_handle=self.stream.cuda_stream)
-self.stream.synchronize()
+self.context.execute_async_v3(stream_handle=stream.cuda_stream)
+stream.synchronize()
return out

def load_state_dict(self, sd, strict=False):
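The loader change drops the privately owned CUDA stream and instead executes on the default stream of the input tensor's device, which keeps the TensorRT launch ordered with PyTorch's own kernels on that device. A minimal sketch of the pattern, assuming `context` is a TensorRT `IExecutionContext` whose tensor addresses have already been bound:

```python
import torch

def run_engine(context, x: torch.Tensor):
    # Use PyTorch's default stream for x's device so the TensorRT launch
    # runs after any PyTorch work already queued on that stream.
    stream = torch.cuda.default_stream(x.device)
    context.execute_async_v3(stream_handle=stream.cuda_stream)
    stream.synchronize()  # block until the engine finishes
```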