Skip to content

Commit

Permalink
Add Flux model support for InstantX style controlnet residuals (#4444)
Browse files Browse the repository at this point in the history
* Add Flux model support for InstantX style controlnet residuals

* Refactor Flux controlnet residual step to a separate method

* Rollback minor change

* New format for applying controlnet residuals: input->double_blocks, output->single_blocks

* Adjust XLabs Flux controlnet to fit new syntax of applying Flux controlnet residuals

* Remove unnecessary import and minor style change
  • Loading branch information
EeroHeikkinen authored Aug 18, 2024
1 parent 310ad09 commit e68763f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
2 changes: 1 addition & 1 deletion comfy/ldm/flux/controlnet_xlabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def forward_orig(
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

return {"output": (controlnet_block_res_samples * 10)[:19]}
return {"input": (controlnet_block_res_samples * 10)[:19]}

def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
hint = hint * 2.0 - 1.0
Expand Down
23 changes: 16 additions & 7 deletions comfy/ldm/flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,28 @@ def forward_orig(
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)

for i in range(len(self.double_blocks)):
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
for i, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

if control is not None: #Controlnet
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add

img = torch.cat((txt, img), 1)

for i, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe)

if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img += add
img[:, txt.shape[1] :, ...] += add

img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]

img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
Expand Down

0 comments on commit e68763f

Please sign in to comment.