diff --git a/diffusion/models/models.py b/diffusion/models/models.py
index 74caf510..6990710f 100644
--- a/diffusion/models/models.py
+++ b/diffusion/models/models.py
@@ -28,6 +28,7 @@ def stable_diffusion_2(
     model_name: str = 'stabilityai/stable-diffusion-2-base',
     pretrained: bool = True,
+    prediction_type: str = 'epsilon',
    train_metrics: Optional[List] = None,
     val_metrics: Optional[List] = None,
     val_guidance_scales: Optional[List] = None,
@@ -45,6 +46,8 @@ def stable_diffusion_2(
     Args:
         model_name (str, optional): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'.
         pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
+        prediction_type (str): The type of prediction to use. Must be one of 'sample',
+            'epsilon', or 'v_prediction'. Default: `epsilon`.
         train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
             [MeanSquaredError()].
         val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to
@@ -86,7 +89,14 @@ def stable_diffusion_2(
 
     tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
     noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler')
-    inference_noise_scheduler = DDIMScheduler.from_pretrained(model_name, subfolder='scheduler')
+    inference_noise_scheduler = DDIMScheduler(num_train_timesteps=noise_scheduler.config.num_train_timesteps,
+                                              beta_start=noise_scheduler.config.beta_start,
+                                              beta_end=noise_scheduler.config.beta_end,
+                                              beta_schedule=noise_scheduler.config.beta_schedule,
+                                              trained_betas=noise_scheduler.config.trained_betas,
+                                              clip_sample=noise_scheduler.config.clip_sample,
+                                              set_alpha_to_one=noise_scheduler.config.set_alpha_to_one,
+                                              prediction_type=prediction_type)
 
     model = StableDiffusion(
         unet=unet,
@@ -95,6 +105,7 @@ def stable_diffusion_2(
         tokenizer=tokenizer,
         noise_scheduler=noise_scheduler,
         inference_noise_scheduler=inference_noise_scheduler,
+        prediction_type=prediction_type,
         train_metrics=train_metrics,
         val_metrics=val_metrics,
         val_guidance_scales=val_guidance_scales,
diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py
index b3864492..01688c04 100644
--- a/diffusion/models/stable_diffusion.py
+++ b/diffusion/models/stable_diffusion.py
@@ -37,10 +37,8 @@ class StableDiffusion(ComposerModel):
         num_images_per_prompt (int): How many images to generate per prompt
             for evaluation. Default: `1`.
         loss_fn (torch.nn.Module): torch loss function. Default: `F.mse_loss`.
-        prediction_type (str): `epsilon` or `v_prediction`. `v_prediction` is
-            used in parts of the stable diffusion v2.1 training process.
-            See https://arxiv.org/pdf/2202.00512.pdf.
-            Default: `None` (uses whatever the pretrained model used)
+        prediction_type (str): The type of prediction to use. Must be one of 'sample',
+            'epsilon', or 'v_prediction'. Default: `epsilon`.
         train_metrics (list): List of torchmetrics to calculate during training.
             Default: `None`.
         val_metrics (list): List of torchmetrics to calculate during validation.
@@ -74,6 +72,7 @@ def __init__(self,
                  noise_scheduler,
                  inference_noise_scheduler,
                  loss_fn=F.mse_loss,
+                 prediction_type: str = 'epsilon',
                  train_metrics: Optional[List] = None,
                  val_metrics: Optional[List] = None,
                  val_seed: int = 1138,
@@ -91,6 +90,9 @@ def __init__(self,
         self.vae = vae
         self.noise_scheduler = noise_scheduler
         self.loss_fn = loss_fn
+        self.prediction_type = prediction_type.lower()
+        if self.prediction_type not in ['sample', 'epsilon', 'v_prediction']:
+            raise ValueError(f'prediction type must be one of sample, epsilon, or v_prediction. Got {prediction_type}')
         self.val_seed = val_seed
         self.image_key = image_key
         self.image_latents_key = image_latents_key
@@ -178,9 +180,18 @@ def forward(self, batch):
         # Add noise to the inputs (forward diffusion)
         noise = torch.randn_like(latents)
         noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
-
+        # Generate the targets
+        if self.prediction_type == 'epsilon':
+            targets = noise
+        elif self.prediction_type == 'sample':
+            targets = latents
+        elif self.prediction_type == 'v_prediction':
+            targets = self.noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(
+                f'prediction type must be one of sample, epsilon, or v_prediction. Got {self.prediction_type}')
         # Forward through the model
-        return self.unet(noised_latents, timesteps, conditioning)['sample'], noise, timesteps
+        return self.unet(noised_latents, timesteps, conditioning)['sample'], targets, timesteps
 
     def loss(self, outputs, batch):
         """Loss between unet output and added noise, typically mse."""
@@ -192,7 +203,7 @@ def eval_forward(self, batch, outputs=None):
         if outputs is not None:
             return outputs
         # Get unet outputs
-        unet_out, noise, timesteps = self.forward(batch)
+        unet_out, targets, timesteps = self.forward(batch)
         # Sample images from the prompts in the batch
         prompts = batch[self.text_key]
         height, width = batch[self.image_key].shape[-2], batch[self.image_key].shape[-1]
@@ -205,7 +216,7 @@ def eval_forward(self, batch, outputs=None):
                                        seed=self.val_seed,
                                        progress_bar=False)
             generated_images[guidance_scale] = gen_images
-        return unet_out, noise, timesteps, generated_images
+        return unet_out, targets, timesteps, generated_images
 
     def get_metrics(self, is_train: bool = False):
         if is_train:
@@ -363,16 +374,16 @@ def generate(
                 latent_model_input = latents
 
             latent_model_input = self.inference_scheduler.scale_model_input(latent_model_input, t)
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            # Model prediction
+            pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
             if do_classifier_free_guidance:
-                # perform guidance
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                # perform guidance. Note this is only technically correct for prediction_type 'epsilon'
+                pred_uncond, pred_text = pred.chunk(2)
+                pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.inference_scheduler.step(noise_pred, t, latents, generator=rng_generator).prev_sample
+            latents = self.inference_scheduler.step(pred, t, latents, generator=rng_generator).prev_sample
 
         # We now use the vae to decode the generated latents back into the image.
         # scale and decode the image latents with vae
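For context on the new `v_prediction` branch: below is a minimal standalone sketch (not part of this PR) of the target that `DDPMScheduler.get_velocity` produces under the diffusers convention, v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0. The tensor shapes are illustrative assumptions; the only project-specific call referenced is the `stable_diffusion_2(prediction_type=...)` entry point added in this diff.

```python
# Illustrative sketch: compare diffusers' get_velocity against the closed-form
# v-prediction target v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0.
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

latents = torch.randn(2, 4, 64, 64)  # stand-in for VAE latents (x_0); shapes are assumptions
noise = torch.randn_like(latents)    # epsilon
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],))

# Target used in forward() when self.prediction_type == 'v_prediction'
velocity = scheduler.get_velocity(latents, noise, timesteps)

# Manual computation for comparison, broadcasting alpha_bar_t over the batch
alpha_bar_t = scheduler.alphas_cumprod[timesteps].view(-1, 1, 1, 1)
expected = alpha_bar_t.sqrt() * noise - (1 - alpha_bar_t).sqrt() * latents
assert torch.allclose(velocity, expected, atol=1e-5)

# With this PR, the prediction type is chosen at construction time, e.g.:
#   model = stable_diffusion_2(prediction_type='v_prediction')
```

Building the DDIM inference scheduler from the DDPM training scheduler's config (rather than `from_pretrained`) keeps both schedulers on the same betas while letting `prediction_type` override whatever the pretrained checkpoint was configured with.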