Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and enhancements for alternate img2img script for stable diffusion XL #16761

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions scripts/img2imgalt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
import torch
import k_diffusion as K

# Debugging notes - for sd1.5, the original apply_model method that gets called is in modules.sd_hijack_utils and is ldm.models.diffusion.ddpm.LatentDiffusion
# For sdxl - OpenAIWrapper will be called, which will call the underlying diffusion_model


def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
x = p.init_latent

Expand All @@ -30,15 +34,25 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):

x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigmas[i] * s_in] * 2)
cond_in = torch.cat([uncond, cond])

if shared.sd_model.is_sdxl:
cond_tensor = cond['crossattn']
uncond_tensor = uncond['crossattn']
cond_in = torch.cat([uncond_tensor, cond_tensor])
else:
cond_in = torch.cat([uncond, cond])

image_conditioning = torch.cat([p.image_conditioning] * 2)
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}

c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
t = dnw.sigma_to_t(sigma_in)

eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
if shared.sd_model.is_sdxl:
eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} )
else:
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)

denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)

denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
Expand All @@ -64,6 +78,13 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):

# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
if shared.sd_model.is_sdxl:
cond_tensor = cond['crossattn']
uncond_tensor = uncond['crossattn']
cond_in = torch.cat([uncond_tensor, cond_tensor])
else:
cond_in = torch.cat([uncond, cond])

arrmansa marked this conversation as resolved.
Show resolved Hide resolved
x = p.init_latent

s_in = x.new_ones([x.shape[0]])
Expand All @@ -82,7 +103,14 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):

x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
cond_in = torch.cat([uncond, cond])


if shared.sd_model.is_sdxl:
cond_tensor = cond['crossattn']
uncond_tensor = uncond['crossattn']
cond_in = torch.cat([uncond_tensor, cond_tensor])
else:
cond_in = torch.cat([uncond, cond])

image_conditioning = torch.cat([p.image_conditioning] * 2)
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
Expand All @@ -94,7 +122,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
else:
t = dnw.sigma_to_t(sigma_in)

eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)

if shared.sd_model.is_sdxl:
eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} )
else:
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)

denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)

denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
Expand Down
Loading