Skip to content

Fix img2img with dpm++ karras when there are duped timesteps #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: pai/main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -561,40 +561,31 @@ def check_inputs(
)

def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
# get the original timestep using init_timestep
if denoising_start is None:
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
else:
t_start = 0
denoising_start = 1 - strength

timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
timesteps = self.scheduler.timesteps

# Strength is irrelevant if we directly request a timestep to start at;
# that is, strength is determined by the denoising_start instead.
if denoising_start is not None:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
)

num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
# if the scheduler is a 2nd order scheduler we might have to do +1
# because `num_inference_steps` might be even given that every timestep
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
# mean that we cut the timesteps in the middle of the denoising step
# (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
# we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
num_inference_steps = num_inference_steps + 1

# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps = timesteps[-num_inference_steps:]
return timesteps, num_inference_steps

return timesteps, num_inference_steps - t_start
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
# if the scheduler is a 2nd order scheduler we might have to do +1
# because `num_inference_steps` might be even given that every timestep
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
# mean that we cut the timesteps in the middle of the denoising step
# (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
# we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
num_inference_steps = num_inference_steps + 1

# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps = timesteps[-num_inference_steps:]
return timesteps, num_inference_steps

def prepare_latents(
self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,8 +773,8 @@ def _init_step_index(self, timestep):
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
# elif len(index_candidates) > 1:
# step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()

Expand Down Expand Up @@ -887,7 +887,7 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = [(schedule_timesteps == t).nonzero().tolist()[0] for t in timesteps]

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
Expand Down