Add option to set the prediction type for stable diffusion (#49)
coryMosaicML authored Jul 17, 2023
1 parent eb98c27 commit 7fb5792
Showing 2 changed files with 37 additions and 15 deletions.
13 changes: 12 additions & 1 deletion diffusion/models/models.py
@@ -28,6 +28,7 @@
def stable_diffusion_2(
model_name: str = 'stabilityai/stable-diffusion-2-base',
pretrained: bool = True,
prediction_type: str = 'epsilon',
train_metrics: Optional[List] = None,
val_metrics: Optional[List] = None,
val_guidance_scales: Optional[List] = None,
@@ -45,6 +46,8 @@ def stable_diffusion_2(
Args:
model_name (str, optional): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'.
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
prediction_type (str): The type of prediction to use. Must be one of 'sample',
'epsilon', or 'v_prediction'. Default: `epsilon`.
train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
[MeanSquaredError()].
val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to
@@ -86,7 +89,14 @@ def stable_diffusion_2(

tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler')
inference_noise_scheduler = DDIMScheduler.from_pretrained(model_name, subfolder='scheduler')
inference_noise_scheduler = DDIMScheduler(num_train_timesteps=noise_scheduler.config.num_train_timesteps,
beta_start=noise_scheduler.config.beta_start,
beta_end=noise_scheduler.config.beta_end,
beta_schedule=noise_scheduler.config.beta_schedule,
trained_betas=noise_scheduler.config.trained_betas,
clip_sample=noise_scheduler.config.clip_sample,
set_alpha_to_one=noise_scheduler.config.set_alpha_to_one,
prediction_type=prediction_type)

model = StableDiffusion(
unet=unet,
Expand All @@ -95,6 +105,7 @@ def stable_diffusion_2(
tokenizer=tokenizer,
noise_scheduler=noise_scheduler,
inference_noise_scheduler=inference_noise_scheduler,
prediction_type=prediction_type,
train_metrics=train_metrics,
val_metrics=val_metrics,
val_guidance_scales=val_guidance_scales,
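For context, a minimal usage sketch of the new argument (the import path follows the diffusion/models/models.py layout shown above; the call site itself is illustrative and not part of this commit):

from diffusion.models.models import stable_diffusion_2

# Build the composer model with v-prediction targets instead of the
# default epsilon (noise) targets.
model = stable_diffusion_2(
    model_name='stabilityai/stable-diffusion-2-base',
    pretrained=True,
    prediction_type='v_prediction',
)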
39 changes: 25 additions & 14 deletions diffusion/models/stable_diffusion.py
@@ -37,10 +37,8 @@ class StableDiffusion(ComposerModel):
num_images_per_prompt (int): How many images to generate per prompt
for evaluation. Default: `1`.
loss_fn (torch.nn.Module): torch loss function. Default: `F.mse_loss`.
prediction_type (str): `epsilon` or `v_prediction`. `v_prediction` is
used in parts of the stable diffusion v2.1 training process.
See https://arxiv.org/pdf/2202.00512.pdf.
Default: `None` (uses whatever the pretrained model used)
prediction_type (str): The type of prediction to use. Must be one of 'sample',
'epsilon', or 'v_prediction'. Default: `epsilon`.
train_metrics (list): List of torchmetrics to calculate during training.
Default: `None`.
val_metrics (list): List of torchmetrics to calculate during validation.
@@ -74,6 +72,7 @@ def __init__(self,
noise_scheduler,
inference_noise_scheduler,
loss_fn=F.mse_loss,
prediction_type: str = 'epsilon',
train_metrics: Optional[List] = None,
val_metrics: Optional[List] = None,
val_seed: int = 1138,
@@ -91,6 +90,9 @@ def __init__(self,
self.vae = vae
self.noise_scheduler = noise_scheduler
self.loss_fn = loss_fn
self.prediction_type = prediction_type.lower()
if self.prediction_type not in ['sample', 'epsilon', 'v_prediction']:
raise ValueError(f'prediction type must be one of sample, epsilon, or v_prediction. Got {prediction_type}')
self.val_seed = val_seed
self.image_key = image_key
self.image_latents_key = image_latents_key
@@ -178,9 +180,18 @@ def forward(self, batch):
# Add noise to the inputs (forward diffusion)
noise = torch.randn_like(latents)
noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

# Generate the targets
if self.prediction_type == 'epsilon':
targets = noise
elif self.prediction_type == 'sample':
targets = latents
elif self.prediction_type == 'v_prediction':
targets = self.noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f'prediction type must be one of sample, epsilon, or v_prediction. Got {self.prediction_type}')
# Forward through the model
return self.unet(noised_latents, timesteps, conditioning)['sample'], noise, timesteps
return self.unet(noised_latents, timesteps, conditioning)['sample'], targets, timesteps

def loss(self, outputs, batch):
"""Loss between unet output and added noise, typically mse."""
@@ -192,7 +203,7 @@ def eval_forward(self, batch, outputs=None):
if outputs is not None:
return outputs
# Get unet outputs
unet_out, noise, timesteps = self.forward(batch)
unet_out, targets, timesteps = self.forward(batch)
# Sample images from the prompts in the batch
prompts = batch[self.text_key]
height, width = batch[self.image_key].shape[-2], batch[self.image_key].shape[-1]
Expand All @@ -205,7 +216,7 @@ def eval_forward(self, batch, outputs=None):
seed=self.val_seed,
progress_bar=False)
generated_images[guidance_scale] = gen_images
return unet_out, noise, timesteps, generated_images
return unet_out, targets, timesteps, generated_images

def get_metrics(self, is_train: bool = False):
if is_train:
@@ -363,16 +374,16 @@ def generate(
latent_model_input = latents

latent_model_input = self.inference_scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# Model prediction
pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

if do_classifier_free_guidance:
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance. Note this is only technically correct for prediction_type 'epsilon'
pred_uncond, pred_text = pred.chunk(2)
pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.inference_scheduler.step(noise_pred, t, latents, generator=rng_generator).prev_sample
latents = self.inference_scheduler.step(pred, t, latents, generator=rng_generator).prev_sample

# We now use the vae to decode the generated latents back into the image.
# scale and decode the image latents with vae
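For reference, the target selection added in forward() above boils down to the following relationships (a sketch assuming a diffusers-style DDPMScheduler; the helper function and its name are illustrative, not part of this commit):

# Sketch of the three training targets selected in forward() above.
def make_targets(noise_scheduler, latents, noise, timesteps, prediction_type):
    if prediction_type == 'epsilon':
        return noise  # the model learns to predict the added noise
    if prediction_type == 'sample':
        return latents  # the model learns to predict the clean latents directly
    if prediction_type == 'v_prediction':
        # v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents,
        # which is what DDPMScheduler.get_velocity computes.
        return noise_scheduler.get_velocity(latents, noise, timesteps)
    raise ValueError(f'prediction_type must be one of sample, epsilon, or v_prediction. Got {prediction_type}')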
