Add option to set the prediction type for stable diffusion #49

Merged · 5 commits · Jul 17, 2023
Changes from 3 commits
diffusion/models/models.py (14 changes: 13 additions & 1 deletion)

@@ -28,6 +28,7 @@
 def stable_diffusion_2(
     model_name: str = 'stabilityai/stable-diffusion-2-base',
     pretrained: bool = True,
+    prediction_type: str = 'epsilon',
     train_metrics: Optional[List] = None,
     val_metrics: Optional[List] = None,
     val_guidance_scales: Optional[List] = None,
@@ -45,6 +46,8 @@ def stable_diffusion_2(
     Args:
         model_name (str, optional): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'.
         pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
+        prediction_type (str): The type of prediction to use. Must be one of 'sample',
+            'epsilon', or 'v_prediction'. Default: 'epsilon'.
         train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
             [MeanSquaredError()].
         val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to
@@ -86,7 +89,15 @@ def stable_diffusion_2(
 
     tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
     noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler')
-    inference_noise_scheduler = DDIMScheduler.from_pretrained(model_name, subfolder='scheduler')
+    inference_noise_scheduler = DDIMScheduler(
+        num_train_timesteps=noise_scheduler.config.num_train_timesteps,
+        beta_start=noise_scheduler.config.beta_start,
+        beta_end=noise_scheduler.config.beta_end,
+        beta_schedule=noise_scheduler.config.beta_schedule,
+        trained_betas=noise_scheduler.config.trained_betas,
+        clip_sample=noise_scheduler.config.clip_sample,
+        set_alpha_to_one=noise_scheduler.config.set_alpha_to_one,
+        prediction_type=prediction_type)
 
     model = StableDiffusion(
         unet=unet,
@@ -95,6 +106,7 @@ def stable_diffusion_2(
         tokenizer=tokenizer,
         noise_scheduler=noise_scheduler,
         inference_noise_scheduler=inference_noise_scheduler,
+        prediction_type=prediction_type,
        train_metrics=train_metrics,
         val_metrics=val_metrics,
         val_guidance_scales=val_guidance_scales,
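With the factory change above, opting into v-prediction becomes a one-line change for callers. A minimal, hypothetical usage sketch (the import path diffusion.models is an assumption about the package layout, not something shown in this diff):

# Hypothetical usage sketch; assumes stable_diffusion_2 is re-exported from
# diffusion.models and that the SD2-base weights are reachable via the Hub.
from diffusion.models import stable_diffusion_2

# The factory now computes matching training targets inside StableDiffusion
# and builds a DDIM inference scheduler with the same prediction_type.
model = stable_diffusion_2(prediction_type='v_prediction')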
diffusion/models/stable_diffusion.py (39 changes: 25 additions & 14 deletions)

@@ -37,10 +37,8 @@ class StableDiffusion(ComposerModel):
         num_images_per_prompt (int): How many images to generate per prompt
             for evaluation. Default: `1`.
         loss_fn (torch.nn.Module): torch loss function. Default: `F.mse_loss`.
-        prediction_type (str): `epsilon` or `v_prediction`. `v_prediction` is
-            used in parts of the stable diffusion v2.1 training process.
-            See https://arxiv.org/pdf/2202.00512.pdf.
-            Default: `None` (uses whatever the pretrained model used)
+        prediction_type (str): The type of prediction to use. Must be one of 'sample',
+            'epsilon', or 'v_prediction'. Default: `epsilon`.
         train_metrics (list): List of torchmetrics to calculate during training.
             Default: `None`.
         val_metrics (list): List of torchmetrics to calculate during validation.
@@ -74,6 +72,7 @@ def __init__(self,
                  noise_scheduler,
                  inference_noise_scheduler,
                  loss_fn=F.mse_loss,
+                 prediction_type: str = 'epsilon',
                  train_metrics: Optional[List] = None,
                  val_metrics: Optional[List] = None,
                  val_seed: int = 1138,
@@ -91,6 +90,9 @@ def __init__(self,
         self.vae = vae
         self.noise_scheduler = noise_scheduler
         self.loss_fn = loss_fn
+        if prediction_type.lower() not in ['sample', 'epsilon', 'v_prediction']:
+            raise ValueError(f'prediction type must be one of sample, epsilon, or v_prediction. Got {prediction_type}')
+        self.prediction_type = prediction_type.lower()
         self.val_seed = val_seed
         self.image_key = image_key
         self.image_latents_key = image_latents_key
@@ -178,9 +180,18 @@ def forward(self, batch):
         # Add noise to the inputs (forward diffusion)
         noise = torch.randn_like(latents)
         noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
+
+        # Generate the targets
+        if self.prediction_type == 'epsilon':
+            targets = noise
+        elif self.prediction_type == 'sample':
+            targets = latents
+        elif self.prediction_type == 'v_prediction':
+            targets = self.noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(
+                f'prediction type must be one of sample, epsilon, or v_prediction. Got {self.prediction_type}')
         # Forward through the model
-        return self.unet(noised_latents, timesteps, conditioning)['sample'], noise, timesteps
+        return self.unet(noised_latents, timesteps, conditioning)['sample'], targets, timesteps
 
     def loss(self, outputs, batch):
         """Loss between unet output and added noise, typically mse."""
@@ -192,7 +203,7 @@ def eval_forward(self, batch, outputs=None):
         if outputs is not None:
             return outputs
         # Get unet outputs
-        unet_out, noise, timesteps = self.forward(batch)
+        unet_out, targets, timesteps = self.forward(batch)
         # Sample images from the prompts in the batch
         prompts = batch[self.text_key]
         height, width = batch[self.image_key].shape[-2], batch[self.image_key].shape[-1]
@@ -205,7 +216,7 @@ def eval_forward(self, batch, outputs=None):
                    seed=self.val_seed,
                    progress_bar=False)
            generated_images[guidance_scale] = gen_images
-        return unet_out, noise, timesteps, generated_images
+        return unet_out, targets, timesteps, generated_images
 
     def get_metrics(self, is_train: bool = False):
         if is_train:
@@ -363,16 +374,16 @@ def generate(
                 latent_model_input = latents
 
             latent_model_input = self.inference_scheduler.scale_model_input(latent_model_input, t)
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            # Model prediction
+            pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
             if do_classifier_free_guidance:
-                # perform guidance
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                # perform guidance. Note this is only technically correct for prediction_type 'epsilon'
+                pred_uncond, pred_text = pred.chunk(2)
+                pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.inference_scheduler.step(noise_pred, t, latents, generator=rng_generator).prev_sample
+            latents = self.inference_scheduler.step(pred, t, latents, generator=rng_generator).prev_sample
 
         # We now use the vae to decode the generated latents back into the image.
         # scale and decode the image latents with vae
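Why prediction_type has to reach the inference scheduler at all: in diffusers, scheduler.step() is what converts the raw UNet output into the previous latent, and it interprets that output according to its configured prediction_type. A minimal sketch of that contract (random tensors stand in for a real UNet; shapes are illustrative only):

import torch
from diffusers import DDIMScheduler

# Sketch only: a v_prediction model stepped by an epsilon-configured scheduler
# would decode garbage, which is why the factory threads prediction_type here.
scheduler = DDIMScheduler(num_train_timesteps=1000, prediction_type='v_prediction')
scheduler.set_timesteps(50)

latents = torch.randn(1, 4, 64, 64)
for t in scheduler.timesteps:
    pred = torch.randn_like(latents)  # stand-in for unet(latents, t, ...).sample
    latents = scheduler.step(pred, t, latents).prev_sample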