diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 5ab83398..e2eed1a8 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -44,8 +44,8 @@ def stable_diffusion_2( prompts. Args: - model_name (str, optional): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'. - pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. + model_name (str): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'. + pretrained (bool): Whether to load pretrained weights. Defaults to True. prediction_type (str): The type of prediction to use. Must be one of 'sample', 'epsilon', or 'v_prediction'. Default: `epsilon`. train_metrics (list, optional): List of metrics to compute during training. If None, defaults to @@ -54,12 +54,12 @@ def stable_diffusion_2( [MeanSquaredError(), FrechetInceptionDistance(normalize=True)]. val_guidance_scales (list, optional): List of scales to use for validation guidance. If None, defaults to [1.0, 3.0, 7.0]. - val_seed (int, optional): Seed to use for generating evaluation images. Defaults to 1138. + val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. loss_bins (list, optional): List of tuples of (min, max) values to use for loss binning. If None, defaults to [(0, 1)]. - precomputed_latents (bool, optional): Whether to use precomputed latents. Defaults to False. - encode_latents_in_fp16 (bool, optional): Whether to encode latents in fp16. Defaults to True. - fsdp (bool, optional): Whether to use FSDP. Defaults to True. + precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. + encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. + fsdp (bool): Whether to use FSDP. Defaults to True. """ if train_metrics is None: train_metrics = [MeanSquaredError()] @@ -144,13 +144,14 @@ def stable_diffusion_xl( prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2. Args: - model_name (str, optional): Name of the model to load. Determines the text encoder, tokenizer, + model_name (str): Name of the model to load. Determines the text encoder, tokenizer, and noise scheduler. Defaults to 'stabilityai/stable-diffusion-2-base'. - unet_model_name (str, optional): Name of the UNet model to load. Defaults to + unet_model_name (str): Name of the UNet model to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. - vae_model_name (str, optional): Name of the VAE model to load. Defaults to - 'madebyollin/sdxl-vae-fp16-fix'. - pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. + vae_model_name (str): Name of the VAE model to load. Defaults to + 'madebyollin/sdxl-vae-fp16-fix' as the official VAE checkpoint (from + 'stabilityai/stable-diffusion-xl-base-1.0') is not compatible with fp16. + pretrained (bool): Whether to load pretrained weights. Defaults to True. prediction_type (str): The type of prediction to use. Must be one of 'sample', 'epsilon', or 'v_prediction'. Default: `epsilon`. train_metrics (list, optional): List of metrics to compute during training. If None, defaults to @@ -159,12 +160,12 @@ def stable_diffusion_xl( [MeanSquaredError(), FrechetInceptionDistance(normalize=True)]. val_guidance_scales (list, optional): List of scales to use for validation guidance. If None, defaults to [1.0, 3.0, 7.0]. - val_seed (int, optional): Seed to use for generating evaluation images. Defaults to 1138. + val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. loss_bins (list, optional): List of tuples of (min, max) values to use for loss binning. If None, defaults to [(0, 1)]. - precomputed_latents (bool, optional): Whether to use precomputed latents. Defaults to False. - encode_latents_in_fp16 (bool, optional): Whether to encode latents in fp16. Defaults to True. - fsdp (bool, optional): Whether to use FSDP. Defaults to True. + precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. + encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. + fsdp (bool): Whether to use FSDP. Defaults to True. """ if train_metrics is None: train_metrics = [MeanSquaredError()] @@ -180,7 +181,7 @@ def stable_diffusion_xl( metric.requires_grad_(False) if pretrained: - unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet') + raise NotImplementedError('Full SDXL pipeline not implemented yet.') else: config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet') # Currently not doing micro-conditioning, so set config appropriately