From e68763f40c32da0f5fc921f0bd7112fb6319dc92 Mon Sep 17 00:00:00 2001 From: Xrvk Date: Sun, 18 Aug 2024 05:58:23 +0300 Subject: [PATCH] Add Flux model support for InstantX style controlnet residuals (#4444) * Add Flux model support for InstantX style controlnet residuals * Refactor Flux controlnet residual step to a separate method * Rollback minor change * New format for applying controlnet residuals: input->double_blocks, output->single_blocks * Adjust XLabs Flux controlnet to fit new syntax of applying Flux controlnet residuals * Remove unnecessary import and minor style change --- comfy/ldm/flux/controlnet_xlabs.py | 2 +- comfy/ldm/flux/model.py | 23 ++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/flux/controlnet_xlabs.py b/comfy/ldm/flux/controlnet_xlabs.py index 3f40021b2e2..5d700f16c9f 100644 --- a/comfy/ldm/flux/controlnet_xlabs.py +++ b/comfy/ldm/flux/controlnet_xlabs.py @@ -82,7 +82,7 @@ def forward_orig( block_res_sample = controlnet_block(block_res_sample) controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) - return {"output": (controlnet_block_res_samples * 10)[:19]} + return {"input": (controlnet_block_res_samples * 10)[:19]} def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): hint = hint * 2.0 - 1.0 diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index b5373540a84..eb7767ca98a 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -114,19 +114,28 @@ def forward_orig( ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - for i in range(len(self.double_blocks)): - img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe) + for i, block in enumerate(self.double_blocks): + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) - if control is not None: #Controlnet + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + img += add + + img = torch.cat((txt, img), 1) + + for i, block in enumerate(self.single_blocks): + img = block(img, vec=vec, pe=pe) + + if control is not None: # Controlnet control_o = control.get("output") if i < len(control_o): add = control_o[i] if add is not None: - img += add + img[:, txt.shape[1] :, ...] += add - img = torch.cat((txt, img), 1) - for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) img = img[:, txt.shape[1] :, ...] img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)