diff --git a/README.md b/README.md index 7a101c1..3b55d81 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Early days of a lightweight MLIR Python frontend with support for PyTorch (throu Just ```shell -pip install - requirements.txt +pip install -r requirements.txt pip install . --no-build-isolation ``` diff --git a/examples/unet.py b/examples/unet.py index ca51b1d..268fef9 100644 --- a/examples/unet.py +++ b/examples/unet.py @@ -1,27 +1,59 @@ -import pi -from pi import nn -from pi.mlir.utils import pipile -from pi.utils.annotations import annotate_args +import torch +import numpy as np + from pi.models.unet import UNet2DConditionModel +import torch_mlir + +unet = UNet2DConditionModel( + **{ + "block_out_channels": (32, 64), + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + } +) +unet.eval() + +batch_size = 4 +num_channels = 4 +sizes = (32, 32) + + +def floats_tensor(shape, scale=1.0, rng=None, name=None): + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(np.random.random() * scale) + + return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() + +noise = floats_tensor((batch_size, num_channels) + sizes) +time_step = torch.tensor([10]) +encoder_hidden_states = floats_tensor((batch_size, 4, 32)) -class MyUNet(nn.Module): - def __init__(self): - super().__init__() - self.unet = UNet2DConditionModel() +output = unet(noise, time_step, encoder_hidden_states) +print(output) - @annotate_args( - [ - None, - ([-1, -1, -1, -1], pi.float32, True), - ] - ) - def forward(self, x): - y = self.resnet(x) - return y +traced = torch.jit.trace(unet, (noise, time_step, encoder_hidden_states), strict=False) +frozen = torch.jit.freeze(traced) +print(frozen.graph) -test_module = MyUNet() -x = pi.randn((1, 3, 64, 64)) -mlir_module = pipile(test_module, example_args=(x,)) -print(mlir_module) +module = torch_mlir.compile( + frozen, + (noise, time_step, encoder_hidden_states), + use_tracing=True, + output_type=torch_mlir.OutputType.RAW, +) +with open("unet.mlir", "w") as f: + f.write(str(module)) diff --git a/pi/models/unet.py b/pi/models/unet.py new file mode 100644 index 0000000..7491e73 --- /dev/null +++ b/pi/models/unet.py @@ -0,0 +1,11078 @@ +import importlib +import json +import logging +import sys +from pathlib import Path, PosixPath + +logger = logging.getLogger(__name__) + +import functools +import operator as op +import os +from dataclasses import dataclass, fields +from collections import OrderedDict, defaultdict +from copy import deepcopy +from functools import partial +from torch import Tensor, device +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from uuid import uuid4 +import inspect +import math +import numpy as np +import re +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import warnings + + +class MakeCutouts(nn.Module): + def __init__(self, cut_size, cut_power=1.0): + super().__init__() + self.cut_size = cut_size + self.cut_power = cut_power + + def forward(self, pixel_values, num_cutouts): + sideY, sideX = pixel_values.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(num_cutouts): + size = int( + torch.rand([]) ** self.cut_power * (max_size - min_size) + 
min_size + ) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = pixel_values[ + :, :, offsety : offsety + size, offsetx : offsetx + size + ] + cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) + + +class DiffusionUncond(nn.Module): + def __init__(self, global_args): + super().__init__() + self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4) + self.diffusion_ema = deepcopy(self.diffusion) + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + +class AttnProcsLayers(torch.nn.Module): + def __init__(self, state_dict: "Dict[str, torch.Tensor]"): + super().__init__() + self.layers = torch.nn.ModuleList(state_dict.values()) + self.mapping = {k: v for k, v in enumerate(state_dict.keys())} + self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + + def map_to(module, state_dict, *args, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + num = int(key.split(".")[1]) + new_key = key.replace(f"layers.{num}", module.mapping[num]) + new_state_dict[new_key] = value + return new_state_dict + + def map_from(module, state_dict, *args, **kwargs): + all_keys = list(state_dict.keys()) + for key in all_keys: + replace_key = key.split(".processor")[0] + ".processor" + new_key = key.replace( + replace_key, f"layers.{module.rev_mapping[replace_key]}" + ) + state_dict[new_key] = state_dict[key] + del state_dict[key] + + self._register_state_dict_hook(map_to) + self._register_load_state_dict_pre_hook(map_from, with_module=True) + + +def is_xformers_available(): + return _xformers_available + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (`int`): The number of channels in the input and output. + num_head_channels (`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. 
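    Example (illustrative sketch, not part of the upstream docstring; it assumes `AttentionBlock` is in scope from this module). The block is shape-preserving on NCHW inputs, and `channels` must be divisible by both `norm_num_groups` and `num_head_channels`:

    ```python
    import torch

    attn = AttentionBlock(channels=64, num_head_channels=32, norm_num_groups=32)
    x = torch.randn(1, 64, 16, 16)   # (batch, channels, height, width)
    out = attn(x)                    # self-attention over the 16*16 spatial positions
    assert out.shape == x.shape      # same NCHW shape as the input
    ```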
+ """ + + def __init__( + self, + channels: "int", + num_head_channels: "Optional[int]" = None, + norm_num_groups: "int" = 32, + rescale_output_factor: "float" = 1.0, + eps: "float" = 1e-05, + ): + super().__init__() + self.channels = channels + self.num_heads = ( + channels // num_head_channels if num_head_channels is not None else 1 + ) + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm( + num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True + ) + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels, 1) + self._use_memory_efficient_attention_xformers = False + self._attention_op = None + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size * head_size, seq_len, dim // head_size + ) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def set_use_memory_efficient_attention_xformers( + self, + use_memory_efficient_attention_xformers: "bool", + attention_op: "Optional[Callable]" = None, + ): + if use_memory_efficient_attention_xformers: + if not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only available for GPU " + ) + else: + try: + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self._use_memory_efficient_attention_xformers = ( + use_memory_efficient_attention_xformers + ) + self._attention_op = attention_op + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.view(batch, channel, height * width).transpose( + 1, 2 + ) + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + scale = 1 / math.sqrt(self.channels / self.num_heads) + query_proj = self.reshape_heads_to_batch_dim(query_proj) + key_proj = self.reshape_heads_to_batch_dim(key_proj) + value_proj = self.reshape_heads_to_batch_dim(value_proj) + if self._use_memory_efficient_attention_xformers: + hidden_states = xformers.ops.memory_efficient_attention( + query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op + ) + hidden_states = hidden_states + else: + attention_scores = torch.baddbmm( + torch.empty( + query_proj.shape[0], + query_proj.shape[1], + key_proj.shape[1], + dtype=query_proj.dtype, + device=query_proj.device, + ), + query_proj, + key_proj.transpose(-1, -2), + beta=0, + alpha=scale, + ) + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type( + attention_scores.dtype + ) + hidden_states = torch.bmm(attention_probs, value_proj) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch, channel, height, width + ) + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + +class AdaLayerNorm(nn.Module): + """ + Norm layer modified to incorporate timestep embeddings. + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x, timestep): + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = torch.chunk(emb, 2) + x = self.norm(x) * (1 + scale) + shift + return x + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding( + num_classes + use_cfg_embedding, hidden_size + ) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
+ """ + if force_drop_ids is None: + drop_ids = ( + torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + ) + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if self.training and use_dropout or force_drop_ids is not None: + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: "int", + time_embed_dim: "int", + act_fn: "str" = "silu", + out_dim: "int" = None, + post_act_fn: "Optional[str]" = None, + cond_proj_dim=None, + ): + super().__init__() + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + if act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "mish": + self.act = nn.Mish() + elif act_fn == "gelu": + self.act = nn.GELU() + else: + raise ValueError( + f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'" + ) + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + if post_act_fn is None: + self.post_act = None + elif post_act_fn == "silu": + self.post_act = nn.SiLU() + elif post_act_fn == "mish": + self.post_act = nn.Mish() + elif post_act_fn == "gelu": + self.post_act = nn.GELU() + else: + raise ValueError( + f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'" + ) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + if self.act is not None: + sample = self.act(sample) + sample = self.linear_2(sample) + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +def get_timestep_embedding( + timesteps: "torch.Tensor", + embedding_dim: "int", + flip_sin_to_cos: "bool" = False, + downscale_freq_shift: "float" = 1, + scale: "float" = 1, + max_period: "int" = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + emb = scale * emb + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class Timesteps(nn.Module): + def __init__( + self, + num_channels: "int", + flip_sin_to_cos: "bool", + downscale_freq_shift: "float", + ): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class CombinedTimestepLabelEmbeddings(nn.Module): + def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): + super().__init__() + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1 + ) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim + ) + self.class_embedder = LabelEmbedding( + num_classes, embedding_dim, class_dropout_prob + ) + + def forward(self, timestep, class_labels, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj) + class_labels = self.class_embedder(class_labels) + conditioning = timesteps_emb + class_labels + return conditioning + + +class AdaLayerNormZero(nn.Module): + """ + Norm layer adaptive layer norm zero (adaLN-Zero). + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-06) + + def forward(self, x, timestep, class_labels, hidden_dtype=None): + emb = self.linear( + self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)) + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk( + 6, dim=1 + ) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class CrossAttnAddedKVProcessor: + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + residual = hidden_states + hidden_states = hidden_states.view( + hidden_states.shape[0], hidden_states.shape[1], -1 + ).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim( + encoder_hidden_states_key_proj + ) + encoder_hidden_states_value_proj = attn.head_to_batch_dim( + encoder_hidden_states_value_proj + ) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + return hidden_states + + +class CrossAttnProcessor: + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, 
sequence_length, batch_size + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4): + super().__init__() + if rank > min(in_features, out_features): + raise ValueError( + f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}" + ) + self.down = nn.Linear(in_features, rank, bias=False) + self.up = nn.Linear(rank, out_features, bias=False) + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + down_hidden_states = self.down(hidden_states) + up_hidden_states = self.up(down_hidden_states) + return up_hidden_states + + +class LoRACrossAttnProcessor(nn.Module): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + super().__init__() + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank + ) + self.to_v_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank + ) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + scale=1.0, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora( + encoder_hidden_states + ) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora( + encoder_hidden_states + ) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora( + hidden_states + ) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LoRAXFormersCrossAttnProcessor(nn.Module): + def __init__( + self, + hidden_size, + cross_attention_dim, + rank=4, + attention_op: "Optional[Callable]" = None, + ): + super().__init__() + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + self.attention_op = attention_op + 
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank + ) + self.to_v_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank + ) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + scale=1.0, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query).contiguous() + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora( + encoder_hidden_states + ) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora( + encoder_hidden_states + ) + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora( + hidden_states + ) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class SlicedAttnAddedKVProcessor: + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__( + self, + attn: "'CrossAttention'", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + residual = hidden_states + hidden_states = hidden_states.view( + hidden_states.shape[0], hidden_states.shape[1], -1 + ).transpose(1, 2) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + encoder_hidden_states_key_proj = attn.head_to_batch_dim( + encoder_hidden_states_key_proj + ) + encoder_hidden_states_value_proj = attn.head_to_batch_dim( + encoder_hidden_states_value_proj + ) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), + device=query.device, + dtype=query.dtype, + ) + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = ( + attention_mask[start_idx:end_idx] + if attention_mask is not None + else None + ) + attn_slice = 
attn.get_attention_scores( + query_slice, key_slice, attn_mask_slice + ) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + hidden_states[start_idx:end_idx] = attn_slice + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + return hidden_states + + +class SlicedAttnProcessor: + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), + device=query.device, + dtype=query.dtype, + ) + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = ( + attention_mask[start_idx:end_idx] + if attention_mask is not None + else None + ) + attn_slice = attn.get_attention_scores( + query_slice, key_slice, attn_mask_slice + ) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + hidden_states[start_idx:end_idx] = attn_slice + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class XFormersCrossAttnProcessor: + def __init__(self, attention_op: "Optional[Callable]" = None): + self.attention_op = attention_op + + def __call__( + self, + attn: "CrossAttention", + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = hidden_states + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return 
hidden_states + + +def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): + deprecated_kwargs = take_from + values = () + if not isinstance(args[0], tuple): + args = (args,) + for attribute, version_name, message in args: + if version.parse(version.parse(__version__).base_version) >= version.parse( + version_name + ): + raise ValueError( + f"The deprecation tuple {attribute, version_name, message} should be removed since diffusers' version {__version__} is >= {version_name}" + ) + warning = None + if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: + values += (deprecated_kwargs.pop(attribute),) + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." + elif hasattr(deprecated_kwargs, attribute): + values += (getattr(deprecated_kwargs, attribute),) + warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." + elif deprecated_kwargs is None: + warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." + if warning is not None: + warning = warning + " " if standard_warn else "" + warnings.warn(warning + message, FutureWarning, stacklevel=2) + if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: + call_frame = inspect.getouterframes(inspect.currentframe())[1] + filename = call_frame.filename + line_number = call_frame.lineno + function = call_frame.function + key, value = next(iter(deprecated_kwargs.items())) + raise TypeError( + f"{function} in {filename} line {line_number - 1} got an unexpected keyword argument `{key}`" + ) + if len(values) == 0: + return + elif len(values) == 1: + return values[0] + return values + + +class CrossAttention(nn.Module): + """ + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. 
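    Example (illustrative sketch, not part of the upstream docstring; the shapes below merely mimic the Stable Diffusion convention). The processor is chosen automatically: `AttnProcessor2_0` when `F.scaled_dot_product_attention` is available, otherwise `CrossAttnProcessor`:

    ```python
    import torch

    attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
    latents = torch.randn(2, 64, 320)   # (batch, query tokens, query_dim)
    text = torch.randn(2, 77, 768)      # (batch, context tokens, cross_attention_dim)
    out = attn(latents, encoder_hidden_states=text)
    assert out.shape == (2, 64, 320)    # to_out[0] projects back to query_dim
    ```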
+ """ + + def __init__( + self, + query_dim: "int", + cross_attention_dim: "Optional[int]" = None, + heads: "int" = 8, + dim_head: "int" = 64, + dropout: "float" = 0.0, + bias=False, + upcast_attention: "bool" = False, + upcast_softmax: "bool" = False, + cross_attention_norm: "bool" = False, + added_kv_proj_dim: "Optional[int]" = None, + norm_num_groups: "Optional[int]" = None, + out_bias: "bool" = True, + scale_qk: "bool" = True, + processor: "Optional['AttnProcessor']" = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.cross_attention_norm = cross_attention_norm + self.scale = dim_head ** -0.5 if scale_qk else 1.0 + self.heads = heads + self.sliceable_head_dim = heads + self.added_kv_proj_dim = added_kv_proj_dim + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm( + num_channels=inner_dim, + num_groups=norm_num_groups, + eps=1e-05, + affine=True, + ) + else: + self.group_norm = None + if cross_attention_norm: + self.norm_cross = nn.LayerNorm(cross_attention_dim) + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + if processor is None: + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and scale_qk + else CrossAttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, + use_memory_efficient_attention_xformers: "bool", + attention_op: "Optional[Callable]" = None, + ): + is_lora = hasattr(self, "processor") and isinstance( + self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor) + ) + if use_memory_efficient_attention_xformers: + if self.added_kv_proj_dim is not None: + raise NotImplementedError( + "Memory efficient attention with `xformers` is currently not supported when `self.added_kv_proj_dim` is defined." + ) + elif not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only available for GPU " + ) + else: + try: + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + if is_lora: + processor = LoRAXFormersCrossAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor + else: + processor = XFormersCrossAttnProcessor(attention_op=attention_op) + elif is_lora: + processor = LoRACrossAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor + else: + processor = CrossAttnProcessor() + self.set_processor(processor) + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}." + ) + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = CrossAttnAddedKVProcessor() + else: + processor = CrossAttnProcessor() + self.set_processor(processor) + + def set_processor(self, processor: "'AttnProcessor'"): + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info( + f"You are removing possibly trained weights of {self.processor} with {processor}" + ) + self._modules.pop("processor") + self.processor = processor + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + **cross_attention_kwargs, + ): + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def head_to_batch_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size * head_size, seq_len, dim // head_size + ) + return tensor + + def get_attention_scores(self, query, key, attention_mask=None): + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + attention_scores = torch.baddbmm( + baddbmm_input, query, key.transpose(-1, -2), beta=beta, alpha=self.scale + ) + if self.upcast_softmax: + attention_scores = attention_scores.float() + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs + return attention_probs + + def 
prepare_attention_mask(self, attention_mask, target_length, batch_size=None): + if batch_size is None: + deprecate( + "batch_size=None", + "0.0.15", + "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to `prepare_attention_mask` when preparing the attention_mask.", + ) + batch_size = 1 + head_size = self.heads + if attention_mask is None: + return attention_mask + if attention_mask.shape[-1] != target_length: + if attention_mask.device.type == "mps": + padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, + ) + padding = torch.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + return attention_mask + + +class ApproximateGELU(nn.Module): + """ + The approximate form of Gaussian Error Linear Unit (GELU) + + For more details, see section 2: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, dim_in: "int", dim_out: "int"): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x): + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class GEGLU(nn.Module): + """ + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: "int", dim_out: "int"): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + return F.gelu(gate.to(dtype=torch.float32)) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class GELU(nn.Module): + """ + GELU activation function with tanh approximation support with `approximate="tanh"`. + """ + + def __init__(self, dim_in: "int", dim_out: "int", approximate: "str" = "none"): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + self.approximate = approximate + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + """ + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 
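    Example (illustrative, not part of the upstream docstring). With the default `"geglu"` activation the hidden width is `dim * mult`, and the output comes back at `dim_out` (here `dim`):

    ```python
    import torch

    ff = FeedForward(dim=320)          # GEGLU activation, inner width 4 * 320 = 1280
    x = torch.randn(2, 64, 320)
    assert ff(x).shape == (2, 64, 320)
    ```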
+ """ + + def __init__( + self, + dim: "int", + dim_out: "Optional[int]" = None, + mult: "int" = 4, + dropout: "float" = 0.0, + activation_fn: "str" = "geglu", + final_dropout: "bool" = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + self.net = nn.ModuleList([]) + self.net.append(act_fn) + self.net.append(nn.Dropout(dropout)) + self.net.append(nn.Linear(inner_dim, dim_out)) + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class BasicTransformerBlock(nn.Module): + """ + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: "int", + num_attention_heads: "int", + attention_head_dim: "int", + dropout=0.0, + cross_attention_dim: "Optional[int]" = None, + activation_fn: "str" = "geglu", + num_embeds_ada_norm: "Optional[int]" = None, + attention_bias: "bool" = False, + only_cross_attention: "bool" = False, + upcast_attention: "bool" = False, + norm_elementwise_affine: "bool" = True, + norm_type: "str" = "layer_norm", + final_dropout: "bool" = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None and norm_type == "ada_norm_zero" + ) + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None and norm_type == "ada_norm" + ) + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 
+ ) + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + if cross_attention_dim is not None: + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if self.only_cross_attention + else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + norm_hidden_states = self.norm3(hidden_states) + if self.use_ada_layer_norm_zero: + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + ff_output = self.ff(norm_hidden_states) + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + hidden_states = ff_output + hidden_states + return hidden_states + + +class AdaGroupNorm(nn.Module): + """ + GroupNorm layer modified to incorporate timestep embeddings. 
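    Example (illustrative, not part of the upstream docstring). The conditioning embedding is projected to a per-channel `scale`/`shift` pair that modulates a plain group norm:

    ```python
    import torch

    norm = AdaGroupNorm(embedding_dim=512, out_dim=64, num_groups=32)
    x = torch.randn(2, 64, 8, 8)       # (batch, out_dim, height, width)
    emb = torch.randn(2, 512)          # e.g. a timestep embedding
    assert norm(x, emb).shape == x.shape
    ```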
+ """ + + def __init__( + self, + embedding_dim: "int", + out_dim: "int", + num_groups: "int", + act_fn: "Optional[str]" = None, + eps: "float" = 1e-05, + ): + super().__init__() + self.num_groups = num_groups + self.eps = eps + self.act = None + if act_fn == "swish": + self.act = lambda x: F.silu(x) + elif act_fn == "mish": + self.act = nn.Mish() + elif act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "gelu": + self.act = nn.GELU() + self.linear = nn.Linear(embedding_dim, out_dim * 2) + + def forward(self, x, emb): + if self.act: + emb = self.act(emb) + emb = self.linear(emb) + emb = emb[:, :, None, None] + scale, shift = emb.chunk(2, dim=1) + x = F.group_norm(x, self.num_groups, eps=self.eps) + x = x * (1 + scale) + shift + return x + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: "int", + conditioning_channels: "int" = 3, + block_out_channels: "Tuple[int]" = (16, 32, 96, 256), + ): + super().__init__() + self.conv_in = nn.Conv2d( + conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 + ) + self.blocks = nn.ModuleList([]) + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append( + nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1) + ) + self.blocks.append( + nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2) + ) + self.conv_out = zero_module( + nn.Conv2d( + block_out_channels[-1], + conditioning_embedding_channels, + kernel_size=3, + padding=1, + ) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + embedding = self.conv_out(embedding) + return embedding + + +STR_OPERATION_TO_FUNC = { + ">": op.gt, + ">=": op.ge, + "==": op.eq, + "!=": op.ne, + "<=": op.le, + "<": op.lt, +} + + +def compare_versions( + library_or_version: "Union[str, Version]", + operation: "str", + requirement_version: "str", +): + """ + Args: + Compares a library version to some requirement using a given operation. + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. 
+ requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError( + f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}" + ) + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +def is_transformers_version(operation: "str", version: "str"): + """ + Args: + Compares the current Transformers version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _transformers_available: + return False + return compare_versions(parse(_transformers_version), operation, version) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) + if ( + name + in [ + "VersatileDiffusionTextToImagePipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionDualGuidedPipeline", + "StableDiffusionImageVariationPipeline", + "UnCLIPPipeline", + ] + and is_transformers_version("<", "4.25.0") + ): + raise ImportError( + f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install --upgrade transformers \n```" + ) + if ( + name + in [ + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionPix2PixZeroPipeline", + ] + and is_transformers_version("<", "4.26.0") + ): + raise ImportError( + f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install --upgrade transformers \n```" + ) + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattr__(cls, key): + if key.startswith("_"): + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) + + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for key, value in self.items(): + setattr(self, key, value) + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception( + f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." + ) + + def setdefault(self, *args, **kwargs): + raise Exception( + f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." + ) + + def pop(self, *args, **kwargs): + raise Exception( + f"You cannot use ``pop`` on a {self.__class__.__name__} instance." + ) + + def update(self, *args, **kwargs): + raise Exception( + f"You cannot use ``update`` on a {self.__class__.__name__} instance." + ) + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception( + f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." + ) + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception( + f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." 
+ ) + super().__setitem__(name, value) + + +HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" + + +def extract_commit_hash( + resolved_file: "Optional[str]", commit_hash: "Optional[str]" = None +): + """ + Extracts the commit hash from a resolved filename toward a cache file. + """ + if resolved_file is None or commit_hash is not None: + return commit_hash + resolved_file = str(Path(resolved_file).as_posix()) + search = re.search("snapshots/([^/]+)/", resolved_file) + if search is None: + return None + commit_hash = search.groups()[0] + return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None + + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} + +SESSION_ID = uuid4().hex + +_flax_version = "N/A" + +_jax_version = "N/A" + +_onnxruntime_version = "N/A" + +_torch_version = "N/A" + + +class ConfigMixin: + """ + Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all + methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with + - [`~ConfigMixin.from_config`] + - [`~ConfigMixin.save_config`] + + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overridden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overridden by subclass). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). + - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function + should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by + subclass). + """ + + config_name = None + ignore_for_config = [] + has_compatibles = False + _deprecated_kwargs = [] + + def register_to_config(self, **kwargs): + if self.config_name is None: + raise NotImplementedError( + f"Make sure that {self.__class__} has defined a class name `config_name`" + ) + kwargs.pop("kwargs", None) + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + logger.debug(f"Updating config from {previous_dict} to {internal_dict}") + self._internal_dict = FrozenDict(internal_dict) + + def save_config( + self, + save_directory: "Union[str, os.PathLike]", + push_to_hub: "bool" = False, + **kwargs, + ): + if os.path.isfile(save_directory): + raise AssertionError( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + os.makedirs(save_directory, exist_ok=True) + output_config_file = os.path.join(save_directory, self.config_name) + self.to_json_file(output_config_file) + logger.info(f"Configuration saved in {output_config_file}") + + @classmethod + def from_config( + cls, + config: "Union[FrozenDict, Dict[str, Any]]" = None, + return_unused_kwargs=False, + **kwargs, + ): + if "pretrained_model_name_or_path" in kwargs: + config = kwargs.pop("pretrained_model_name_or_path") + if config is None: + raise ValueError( + "Please make sure to provide a config as the first positional argument." + ) + if not isinstance(config, dict): + deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." 
+ if "Scheduler" in cls.__name__: + deprecation_message += f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead. Otherwise, please make sure to pass a configuration dictionary instead. This functionality will be removed in v1.0.0." + elif "Model" in cls.__name__: + deprecation_message += f"If you were trying to load a model, please use {cls}.load_config(...) followed by {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary instead. This functionality will be removed in v1.0.0." + deprecate( + "config-passed-as-path", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + config, kwargs = cls.load_config( + pretrained_model_name_or_path=config, + return_unused_kwargs=True, + **kwargs, + ) + init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + for deprecated_kwarg in cls._deprecated_kwargs: + if deprecated_kwarg in unused_kwargs: + init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) + model = cls(**init_dict) + model.register_to_config(**hidden_dict) + unused_kwargs = {**unused_kwargs, **hidden_dict} + if return_unused_kwargs: + return model, unused_kwargs + else: + return model + + @classmethod + def get_config_dict(cls, *args, **kwargs): + deprecation_message = f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be removed in version v1.0.0" + deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) + return cls.load_config(*args, **kwargs) + + @classmethod + def load_config( + cls, + pretrained_model_name_or_path: "Union[str, os.PathLike]", + return_unused_kwargs=False, + return_commit_hash=False, + **kwargs, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Instantiate a Python class from a config dictionary + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. 
If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False): + Whether unused keyword arguments of the config shall be returned. + return_commit_hash (`bool`, *optional*, defaults to `False): + Whether the commit_hash of the loaded configuration shall be returned. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + _ = kwargs.pop("mirror", None) + subfolder = kwargs.pop("subfolder", None) + user_agent = kwargs.pop("user_agent", {}) + user_agent = {**user_agent, "file_type": "config"} + user_agent = http_user_agent(user_agent) + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if cls.config_name is None: + raise ValueError( + "`self.config_name` is not defined. Note that one should not load a config from `ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" + ) + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile( + os.path.join(pretrained_model_name_or_path, cls.config_name) + ): + config_file = os.path.join( + pretrained_model_name_or_path, cls.config_name + ) + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join( + pretrained_model_name_or_path, subfolder, cls.config_name + ) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + pass + try: + config_dict = cls._dict_from_json_file(config_file) + commit_hash = extract_commit_hash(config_file) + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError( + f"It looks like the config file at '{config_file}' is not a valid JSON file." 
+ ) + if not (return_unused_kwargs or return_commit_hash): + return config_dict + outputs = (config_dict,) + if return_unused_kwargs: + outputs += (kwargs,) + if return_commit_hash: + outputs += (commit_hash,) + return outputs + + @staticmethod + def _get_init_keys(cls): + return set(dict(inspect.signature(cls.__init__).parameters).keys()) + + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): + original_dict = {k: v for k, v in config_dict.items()} + expected_keys = cls._get_init_keys(cls) + expected_keys.remove("self") + if "kwargs" in expected_keys: + expected_keys.remove("kwargs") + if hasattr(cls, "_flax_internal_args"): + for arg in cls._flax_internal_args: + expected_keys.remove(arg) + if len(cls.ignore_for_config) > 0: + expected_keys = expected_keys - set(cls.ignore_for_config) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + if cls.has_compatibles: + compatible_classes = [ + c for c in cls._get_compatibles() if not isinstance(c, DummyObject) + ] + else: + compatible_classes = [] + expected_keys_comp_cls = set() + for c in compatible_classes: + expected_keys_c = cls._get_init_keys(c) + expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c) + expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls) + config_dict = { + k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls + } + orig_cls_name = config_dict.pop("_class_name", cls.__name__) + if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name): + orig_cls = getattr(diffusers_library, orig_cls_name) + unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys + config_dict = { + k: v + for k, v in config_dict.items() + if k not in unexpected_keys_from_orig + } + config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + init_dict = {} + for key in expected_keys: + if key in kwargs and key in config_dict: + config_dict[key] = kwargs.pop(key) + if key in kwargs: + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + init_dict[key] = config_dict.pop(key) + if len(config_dict) > 0: + logger.warning( + f"The config attributes {config_dict} were passed to {cls.__name__}, but are not expected and will be ignored. Please verify your {cls.config_name} configuration file." + ) + passed_keys = set(init_dict.keys()) + if len(expected_keys - passed_keys) > 0: + logger.info( + f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." + ) + unused_kwargs = {**config_dict, **kwargs} + hidden_config_dict = { + k: v for k, v in original_dict.items() if k not in init_dict + } + return init_dict, unused_kwargs, hidden_config_dict + + @classmethod + def _dict_from_json_file(cls, json_file: "Union[str, os.PathLike]"): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @property + def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ + return self._internal_dict + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. 
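+
+        Example (illustrative sketch; `model` stands for any `ConfigMixin` subclass
+        instance, and the exact keys depend on that subclass):
+
+        >>> json_str = model.to_json_string()
+        >>> isinstance(json_str, str)
+        True
+        >>> "_class_name" in json_str  # the class name is always recorded
+        True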
+ """ + config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + config_dict["_class_name"] = self.__class__.__name__ + config_dict["_diffusers_version"] = __version__ + + def to_json_saveable(value): + if isinstance(value, np.ndarray): + value = value.tolist() + elif isinstance(value, PosixPath): + value = str(value) + return value + + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: "Union[str, os.PathLike]"): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +class ImagePositionalEmbeddings(nn.Module): + """ + Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the + height and width of the latent space. + + For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 + + For VQ-diffusion: + + Output vector embeddings are used as input for the transformer. + + Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. + + Args: + num_embed (`int`): + Number of embeddings for the latent pixels embeddings. + height (`int`): + Height of the latent image i.e. the number of height embeddings. + width (`int`): + Width of the latent image i.e. the number of width embeddings. + embed_dim (`int`): + Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. + """ + + def __init__(self, num_embed: "int", height: "int", width: "int", embed_dim: "int"): + super().__init__() + self.height = height + self.width = width + self.num_embed = num_embed + self.embed_dim = embed_dim + self.emb = nn.Embedding(self.num_embed, embed_dim) + self.height_emb = nn.Embedding(self.height, embed_dim) + self.width_emb = nn.Embedding(self.width, embed_dim) + + def forward(self, index): + emb = self.emb(index) + height_emb = self.height_emb( + torch.arange(self.height, device=index.device).view(1, self.height) + ) + height_emb = height_emb.unsqueeze(2) + width_emb = self.width_emb( + torch.arange(self.width, device=index.device).view(1, self.width) + ) + width_emb = width_emb.unsqueeze(1) + pos_emb = height_emb + width_emb + pos_emb = pos_emb.view(1, self.height * self.width, -1) + emb = emb + pos_emb[:, : emb.shape[1], :] + return emb + + +CONFIG_NAME = "config.json" + +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" + +SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" + +WEIGHTS_NAME = "diffusion_pytorch_model.bin" + + +def _add_variant(weights_name: "str", variant: "Optional[str]" = None) -> str: + if variant is not None: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + return weights_name + + +DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] + + +def _get_model_file( + pretrained_model_name_or_path, + *, + weights_name, + subfolder, + cache_dir, + force_download, + proxies, + resume_download, + local_files_only, + use_auth_token, + user_agent, + revision, + commit_hash=None, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + return pretrained_model_name_or_path + elif 
os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + model_file = os.path.join(pretrained_model_name_or_path, weights_name) + return model_file + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + ): + model_file = os.path.join( + pretrained_model_name_or_path, subfolder, weights_name + ) + return model_file + else: + raise EnvironmentError( + f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." + ) + else: + if ( + revision in DEPRECATED_REVISION_ARGS + and ( + weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME + ) + and version.parse(version.parse(__version__).base_version) + >= version.parse("0.17.0") + ): + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=_add_variant(weights_name, revision), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision or commit_hash, + ) + warnings.warn( + f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", + FutureWarning, + ) + return model_file + except: + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. 
\n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.", + FutureWarning, + ) + + +def _load_state_dict_into_model(model_to_load, state_dict): + state_dict = state_dict.copy() + error_msgs = [] + + def load(module: "torch.nn.Module", prefix=""): + args = state_dict, prefix, {}, True, [], [], error_msgs + module._load_from_state_dict(*args) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load) + return error_msgs + + +def get_parameter_device(parameter: "torch.nn.Module"): + try: + return next(parameter.parameters()).device + except StopIteration: + + def find_tensor_attributes( + module: "torch.nn.Module", + ) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: "torch.nn.Module"): + try: + return next(parameter.parameters()).dtype + except StopIteration: + + def find_tensor_attributes( + module: "torch.nn.Module", + ) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def is_torch_version(operation: "str", version: "str"): + """ + Args: + Compares the current PyTorch version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(parse(_torch_version), operation, version) + + +def load_state_dict( + checkpoint_file: "Union[str, os.PathLike]", variant: "Optional[str]" = None +): + """ + Reads a checkpoint file, returning properly formatted errors if they arise. + """ + try: + if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant): + return torch.load(checkpoint_file, map_location="cpu") + else: + return safetensors.torch.load_file(checkpoint_file, device="cpu") + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install git-lfs and run `git lfs install` followed by `git lfs pull` in the folder you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +class ModelMixin(torch.nn.Module): + """ + Base class for all models. + + [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading + and saving models. + + - **config_name** ([`str`]) -- A filename under which the model should be stored when calling + [`~models.ModelMixin.save_pretrained`]. 
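+
+    A minimal save / reload sketch (illustrative; the directory path is a placeholder and the
+    constructor arguments depend on the concrete subclass, e.g. the `UNet2DConditionModel`
+    defined in this module):
+
+        >>> unet = UNet2DConditionModel()                      # any ModelMixin + ConfigMixin subclass
+        >>> unet.save_pretrained("./my_unet_checkpoint")
+        >>> reloaded = UNet2DConditionModel.from_pretrained("./my_unet_checkpoint")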
+ """ + + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + + def __init__(self): + super().__init__() + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + return any( + hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing + for m in self.modules() + ) + + def enable_gradient_checkpointing(self): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if not self._supports_gradient_checkpointing: + raise ValueError( + f"{self.__class__.__name__} does not support gradient checkpointing." + ) + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def disable_gradient_checkpointing(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self._supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + def set_use_memory_efficient_attention_xformers( + self, valid: "bool", attention_op: "Optional[Callable]" = None + ) -> None: + def fn_recursive_set_mem_eff(module: "torch.nn.Module"): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid, attention_op) + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + def enable_xformers_memory_efficient_attention( + self, attention_op: "Optional[Callable]" = None + ): + self.set_use_memory_efficient_attention_xformers(True, attention_op) + + def disable_xformers_memory_efficient_attention(self): + """ + Disable memory efficient attention as implemented in xformers. + """ + self.set_use_memory_efficient_attention_xformers(False) + + def save_pretrained( + self, + save_directory: "Union[str, os.PathLike]", + is_main_process: "bool" = True, + save_function: "Callable" = None, + safe_serialization: "bool" = False, + variant: "Optional[str]" = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~models.ModelMixin.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). 
+ variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + """ + if safe_serialization and not is_safetensors_available(): + raise ImportError( + "`safe_serialization` requires the `safetensors library: `pip install safetensors`." + ) + if os.path.isfile(save_directory): + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + return + os.makedirs(save_directory, exist_ok=True) + model_to_save = self + if is_main_process: + model_to_save.save_config(save_directory) + state_dict = model_to_save.state_dict() + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + if safe_serialization: + safetensors.torch.save_file( + state_dict, + os.path.join(save_directory, weights_name), + metadata={"format": "pt"}, + ) + else: + torch.save(state_dict, os.path.join(save_directory, weights_name)) + logger.info( + f"Model weights saved in {os.path.join(save_directory, weights_name)}" + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: "Optional[Union[str, os.PathLike]]", + **kwargs, + ): + """ + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. 
+ local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_flax`. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. 
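+
+        Example (illustrative sketch; the checkpoint directory is a placeholder, typically one
+        previously produced with [`~ModelMixin.save_pretrained`]):
+
+        >>> model = UNet2DConditionModel.from_pretrained("./my_unet_checkpoint")
+        >>> model, loading_info = UNet2DConditionModel.from_pretrained(
+        ...     "./my_unet_checkpoint", output_loading_info=True
+        ... )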
+ + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip install accelerate\n```\n." + ) + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set `device_map=None`." + ) + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set `low_cpu_mem_usage=False`." + ) + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and dispatching. Please make sure to set `low_cpu_mem_usage=True`." 
+ ) + config_path = pretrained_model_name_or_path + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + user_agent=user_agent, + **kwargs, + ) + model_file = None + if from_flax: + model = cls.from_config(config, **unused_kwargs) + else: + if is_safetensors_available(): + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except: + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + if low_cpu_mem_usage: + with accelerate.init_empty_weights(): + model = cls.from_config(config, **unused_kwargs) + if device_map is None: + param_device = "cpu" + state_dict = load_state_dict(model_file, variant=variant) + missing_keys = set(model.state_dict().keys()) - set( + state_dict.keys() + ) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are missing: \n {', '.join(missing_keys)}. \n Please make sure to pass `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize those weights or else make sure your checkpoint file is correct." + ) + for param_name, param in state_dict.items(): + accepts_dtype = "dtype" in set( + inspect.signature( + set_module_tensor_to_device + ).parameters.keys() + ) + if accepts_dtype: + set_module_tensor_to_device( + model, + param_name, + param_device, + value=param, + dtype=torch_dtype, + ) + else: + set_module_tensor_to_device( + model, param_name, param_device, value=param + ) + else: + accelerate.load_checkpoint_and_dispatch( + model, model_file, device_map, dtype=torch_dtype + ) + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + model = cls.from_config(config, **unused_kwargs) + state_dict = load_state_dict(model_file, variant=variant) + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
+ ) + elif torch_dtype is not None: + model = model + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + model.eval() + if output_loading_info: + return model, loading_info + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + model_state_dict = model.state_dict() + loaded_keys = [k for k in state_dict.keys()] + expected_keys = list(model_state_dict.keys()) + original_loaded_keys = loaded_keys + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + model_to_load = model + + def _find_mismatched_keys( + state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape + != model_state_dict[model_key].shape + ): + mismatched_keys.append( + ( + checkpoint_key, + state_dict[checkpoint_key].shape, + model_state_dict[model_key].shape, + ) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + raise RuntimeError( + f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}" + ) + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info( + f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n" + ) + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint was trained on, you can already use {model.__class__.__name__} for predictions without further training." 
+ ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} and are newly initialized because the shapes did not match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters( + self, only_trainable: "bool" = False, exclude_embeddings: "bool" = False + ) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. + """ + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter + for name, parameter in self.named_parameters() + if name not in embedding_param_names + ] + return sum( + p.numel() + for p in non_embedding_parameters + if p.requires_grad or not only_trainable + ) + else: + return sum( + p.numel() + for p in self.parameters() + if p.requires_grad or not only_trainable + ) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + emb_sin = np.sin(out) + emb_cos = np.cos(out) + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if 
cls_token and extra_tokens > 0: + pos_embed = np.concatenate( + [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + ): + super().__init__() + num_patches = height // patch_size * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + self.proj = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=(patch_size, patch_size), + stride=patch_size, + bias=bias, + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-06) + else: + self.norm = None + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches ** 0.5)) + self.register_buffer( + "pos_embed", + torch.from_numpy(pos_embed).float().unsqueeze(0), + persistent=False, + ) + + def forward(self, latent): + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) + if self.layer_norm: + latent = self.norm(latent) + return latent + self.pos_embed + + +class BaseOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + python dictionary. + + + + You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple + before. + + + """ + + def __post_init__(self): + class_fields = fields(self) + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all( + getattr(self, field.name) is None for field in class_fields[1:] + ) + if other_fields_are_none and isinstance(first_field, dict): + for key, value in first_field.items(): + self[key] = value + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception( + f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." + ) + + def setdefault(self, *args, **kwargs): + raise Exception( + f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." + ) + + def pop(self, *args, **kwargs): + raise Exception( + f"You cannot use ``pop`` on a {self.__class__.__name__} instance." + ) + + def update(self, *args, **kwargs): + raise Exception( + f"You cannot use ``update`` on a {self.__class__.__name__} instance." + ) + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = {k: v for k, v in self.items()} + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + super().__setitem__(key, value) + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. 
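+
+        Illustrative sketch (uses the `Transformer2DModelOutput` dataclass defined below;
+        assumes `torch` is importable):
+
+        >>> out = Transformer2DModelOutput(sample=torch.ones(1))
+        >>> out["sample"] is out.sample is out.to_tuple()[0]
+        True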
+ """ + return tuple(self[k] for k in self.keys()) + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: "torch.FloatTensor" + + +def register_to_config(init): + """ + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does not inherit from `ConfigMixin`." + ) + ignore = getattr(self, "ignore_for_config", []) + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default + for i, (name, p) in enumerate(signature.parameters.items()) + if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + new_kwargs = {**config_init_kwargs, **new_kwargs} + getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. 
The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: "int" = 16, + attention_head_dim: "int" = 88, + in_channels: "Optional[int]" = None, + out_channels: "Optional[int]" = None, + num_layers: "int" = 1, + dropout: "float" = 0.0, + norm_num_groups: "int" = 32, + cross_attention_dim: "Optional[int]" = None, + attention_bias: "bool" = False, + sample_size: "Optional[int]" = None, + num_vector_embeds: "Optional[int]" = None, + patch_size: "Optional[int]" = None, + activation_fn: "str" = "geglu", + num_embeds_ada_norm: "Optional[int]" = None, + use_linear_projection: "bool" = False, + only_cross_attention: "bool" = False, + upcast_attention: "bool" = False, + norm_type: "str" = "layer_norm", + norm_elementwise_affine: "bool" = True, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.is_input_continuous = in_channels is not None and patch_size is None + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config. Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for the `transformer/config.json` file" + deprecate( + "norm_type!=num_embeds_ada_norm", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + norm_type = "ada_norm" + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make sure that either `num_vector_embeds` or `num_patches` is None." 
+ ) + elif ( + not self.is_input_continuous + and not self.is_input_vectorized + and not self.is_input_patches + ): + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size: {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + if self.is_input_continuous: + self.in_channels = in_channels + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=in_channels, + eps=1e-06, + affine=True, + ) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + elif self.is_input_vectorized: + assert ( + sample_size is not None + ), "Transformer2DModel over discrete input must provide sample_size" + assert ( + num_vector_embeds is not None + ), "Transformer2DModel over discrete input must provide num_embed" + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, + embed_dim=inner_dim, + height=self.height, + width=self.width, + ) + elif self.is_input_patches: + assert ( + sample_size is not None + ), "Transformer2DModel over patched input must provide sample_size" + self.height = sample_size + self.width = sample_size + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + ) + for d in range(num_layers) + ] + ) + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + if use_linear_projection: + self.proj_out = nn.Linear(inner_dim, in_channels) + else: + self.proj_out = nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0 + ) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-06) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear( + inner_dim, patch_size * patch_size * self.out_channels + ) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: "bool" = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. 
Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + logits = logits.permute(0, 2, 1) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = ( + self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + ) + hidden_states = self.proj_out_2(hidden_states) + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=( + -1, + height, + width, + self.patch_size, + self.patch_size, + self.out_channels, + ) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=( + -1, + self.out_channels, + height * self.patch_size, + width * self.patch_size, + ) + ) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. 
The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: "int" = 16, + attention_head_dim: "int" = 88, + in_channels: "Optional[int]" = None, + num_layers: "int" = 1, + dropout: "float" = 0.0, + norm_num_groups: "int" = 32, + cross_attention_dim: "Optional[int]" = None, + attention_bias: "bool" = False, + sample_size: "Optional[int]" = None, + num_vector_embeds: "Optional[int]" = None, + activation_fn: "str" = "geglu", + num_embeds_ada_norm: "Optional[int]" = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + self.mix_ratio = 0.5 + self.condition_lengths = [77, 257] + self.transformer_index_for_condition = [1, 0] + + def forward( + self, + hidden_states, + encoder_hidden_states, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + return_dict: "bool" = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`torch.FloatTensor`, *optional*): + Optional attention mask to be applied in CrossAttention + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
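+
+        Note (editorial sketch, not part of the upstream diffusers docstring): the two transformers
+        never see the full `encoder_hidden_states`. The token axis is split according to
+        `condition_lengths` (77 text tokens, then 257 image tokens), each slice is routed to
+        `self.transformers[transformer_index_for_condition[i]]`, and the per-transformer residuals
+        are blended with `mix_ratio`, roughly:
+
+            # illustrative only -- timestep/kwargs omitted
+            delta_text  = transformers[1](x, encoder_hidden_states[:, :77],    return_dict=False)[0] - x
+            delta_image = transformers[0](x, encoder_hidden_states[:, 77:334], return_dict=False)[0] - x
+            output = 0.5 * delta_text + 0.5 * delta_image + x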
+ + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + input_states = hidden_states + encoded_states = [] + tokens_start = 0 + for i in range(2): + condition_state = encoder_hidden_states[ + :, tokens_start : tokens_start + self.condition_lengths[i] + ] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * ( + 1 - self.mix_ratio + ) + output_states = output_states + input_states + if not return_dict: + return (output_states,) + return Transformer2DModelOutput(sample=output_states) + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__( + self, + embedding_size: "int" = 256, + scale: "float" = 1.0, + set_W_to_weight=True, + log=True, + flip_sin_to_cos=False, + ): + super().__init__() + self.weight = nn.Parameter( + torch.randn(embedding_size) * scale, requires_grad=False + ) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + if set_W_to_weight: + self.W = nn.Parameter( + torch.randn(embedding_size) * scale, requires_grad=False + ) + self.weight = self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +@dataclass +class PriorTransformerOutput(BaseOutput): + """ + Args: + predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + The predicted CLIP image embedding conditioned on the CLIP text embedding input. + """ + + predicted_image_embedding: "torch.FloatTensor" + + +class PriorTransformer(ModelMixin, ConfigMixin): + """ + The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the + transformer predicts the image embeddings through a denoising diffusion process. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + For more details, see the original paper: https://arxiv.org/abs/2204.06125 + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP + image embeddings and text embeddings are both the same dimension. + num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the + length of the prompt after it has been tokenized. + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the + projected hidden_states. 
The actual length of the used hidden_states is `num_embeddings + + additional_embeddings`. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + + """ + + @register_to_config + def __init__( + self, + num_attention_heads: "int" = 32, + attention_head_dim: "int" = 64, + num_layers: "int" = 20, + embedding_dim: "int" = 768, + num_embeddings=77, + additional_embeddings=4, + dropout: "float" = 0.0, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.additional_embeddings = additional_embeddings + self.time_proj = Timesteps(inner_dim, True, 0) + self.time_embedding = TimestepEmbedding(inner_dim, inner_dim) + self.proj_in = nn.Linear(embedding_dim, inner_dim) + self.embedding_proj = nn.Linear(embedding_dim, inner_dim) + self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) + self.positional_embedding = nn.Parameter( + torch.zeros(1, num_embeddings + additional_embeddings, inner_dim) + ) + self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn="gelu", + attention_bias=True, + ) + for d in range(num_layers) + ] + ) + self.norm_out = nn.LayerNorm(inner_dim) + self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim) + causal_attention_mask = torch.full( + [ + num_embeddings + additional_embeddings, + num_embeddings + additional_embeddings, + ], + -10000.0, + ) + causal_attention_mask.triu_(1) + causal_attention_mask = causal_attention_mask[None, ...] + self.register_buffer( + "causal_attention_mask", causal_attention_mask, persistent=False + ) + self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim)) + self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim)) + + def forward( + self, + hidden_states, + timestep: "Union[torch.Tensor, float, int]", + proj_embedding: "torch.FloatTensor", + encoder_hidden_states: "torch.FloatTensor", + attention_mask: "Optional[torch.BoolTensor]" = None, + return_dict: "bool" = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + x_t, the currently predicted image embeddings. + timestep (`torch.long`): + Current denoising step. + proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + Projected embedding vector the denoising process is conditioned on. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): + Hidden states of the text embeddings the denoising process is conditioned on. + attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): + Text mask for the text embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain + tuple. + + Returns: + [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: + [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. 
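+
+        Note (editorial sketch, not part of the upstream diffusers docstring): internally a single
+        sequence of length `num_embeddings + additional_embeddings` is assembled by concatenating, in
+        order, the projected text encoder states, the projected `proj_embedding`, the timestep
+        embedding, the projected `hidden_states` (x_t) and the learned `prd_embedding` token. With the
+        defaults (`num_embeddings=77`, `additional_embeddings=4`) that is 77 + 4 = 81 tokens, and the
+        predicted image embedding is read off the final (prd) position after the transformer blocks.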
+ """ + batch_size = hidden_states.shape[0] + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor( + [timesteps], dtype=torch.long, device=hidden_states.device + ) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None] + timesteps = timesteps * torch.ones( + batch_size, dtype=timesteps.dtype, device=timesteps.device + ) + timesteps_projected = self.time_proj(timesteps) + timesteps_projected = timesteps_projected + time_embeddings = self.time_embedding(timesteps_projected) + proj_embeddings = self.embedding_proj(proj_embedding) + encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + hidden_states = self.proj_in(hidden_states) + prd_embedding = self.prd_embedding.expand(batch_size, -1, -1) + positional_embeddings = self.positional_embedding + hidden_states = torch.cat( + [ + encoder_hidden_states, + proj_embeddings[:, None, :], + time_embeddings[:, None, :], + hidden_states[:, None, :], + prd_embedding, + ], + dim=1, + ) + hidden_states = hidden_states + positional_embeddings + if attention_mask is not None: + attention_mask = (1 - attention_mask) * -10000.0 + attention_mask = F.pad( + attention_mask, (0, self.additional_embeddings), value=0.0 + ) + attention_mask = attention_mask[:, None, :] + self.causal_attention_mask + attention_mask = attention_mask.repeat_interleave( + self.config.num_attention_heads, dim=0 + ) + for block in self.transformer_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask) + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states[:, -1] + predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) + if not return_dict: + return (predicted_image_embedding,) + return PriorTransformerOutput( + predicted_image_embedding=predicted_image_embedding + ) + + def post_process_latents(self, prior_latents): + prior_latents = prior_latents * self.clip_std + self.clip_mean + return prior_latents + + +class Upsample1D(nn.Module): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + use_conv_transpose: + out_channels: + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample1D(nn.Module): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. 
+ out_channels: + padding: + """ + + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + if use_conv: + self.conv = nn.Conv1d( + self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.conv(x) + + +class Upsample2D(nn.Module): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + use_conv_transpose: + out_channels: + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + conv = None + if use_conv_transpose: + conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(hidden_states) + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + if output_size is None: + hidden_states = F.interpolate( + hidden_states, scale_factor=2.0, mode="nearest" + ) + else: + hidden_states = F.interpolate( + hidden_states, size=output_size, mode="nearest" + ) + if dtype == torch.bfloat16: + hidden_states = hidden_states + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + return hidden_states + + +class Downsample2D(nn.Module): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. 
+ out_channels: + padding: + """ + + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + if use_conv: + conv = nn.Conv2d( + self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + pad = 0, 1, 0, 1 + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + return hidden_states + + +def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + _, channel, in_h, in_w = tensor.shape + tensor = tensor.reshape(-1, in_h, in_w, 1) + _, in_h, in_w, minor = tensor.shape + kernel_h, kernel_w = kernel.shape + out = tensor.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + return out.view(-1, channel, out_h, out_w) + + +class FirUpsample2D(nn.Module): + def __init__( + self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) + ): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d( + channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `Conv2d()`. + + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. 
+ factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same + datatype as `hidden_states`. + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + kernel = kernel * (gain * factor ** 2) + if self.use_conv: + convH = weight.shape[2] + convW = weight.shape[3] + inC = weight.shape[1] + pad_value = kernel.shape[0] - factor - (convW - 1) + stride = factor, factor + output_shape = (hidden_states.shape[2] - 1) * factor + convH, ( + hidden_states.shape[3] - 1 + ) * factor + convW + output_padding = ( + output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = hidden_states.shape[1] // inC + weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) + weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) + weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) + inverse_conv = F.conv_transpose2d( + hidden_states, + weight, + stride=stride, + output_padding=output_padding, + padding=0, + ) + output = upfirdn2d_native( + inverse_conv, + torch.tensor(kernel, device=inverse_conv.device), + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), + ) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + return output + + def forward(self, hidden_states): + if self.use_conv: + height = self._upsample_2d( + hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel + ) + height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) + return height + + +class FirDownsample2D(nn.Module): + def __init__( + self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) + ): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d( + channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + """Fused `Conv2d()` followed by `downsample_2d()`. + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight: + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * + factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). 
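+
+        Note (editorial, illustrative only): as invoked from `forward` (kernel = `self.fir_kernel =
+        (1, 3, 3, 1)`, `factor=2`), the input is low-pass filtered by the normalized outer product of
+        the kernel and then decimated, e.g. a `[1, 64, 32, 32]` input becomes `[1, 64, 16, 16]`; when
+        `use_conv=True` the strided 3x3 convolution is applied in the same fused pass rather than as
+        a separate op afterwards.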
+ + Returns: + output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and + same datatype as `x`. + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + kernel = kernel * gain + if self.use_conv: + _, _, convH, convW = weight.shape + pad_value = kernel.shape[0] - factor + (convW - 1) + stride_value = [factor, factor] + upfirdn_input = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + pad=((pad_value + 1) // 2, pad_value // 2), + ) + output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + return output + + def forward(self, hidden_states): + if self.use_conv: + downsample_input = self._downsample_2d( + hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel + ) + hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + hidden_states = self._downsample_2d( + hidden_states, kernel=self.fir_kernel, factor=2 + ) + return hidden_states + + +class KDownsample2D(nn.Module): + def __init__(self, pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, x): + x = F.pad(x, (self.pad,) * 4, self.pad_mode) + weight = x.new_zeros( + [x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]] + ) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel + return F.conv2d(x, weight, stride=2) + + +class KUpsample2D(nn.Module): + def __init__(self, pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, x): + x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) + weight = x.new_zeros( + [x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]] + ) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel + return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) + + +def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): + """Downsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). 
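+
+    Note (editorial example, not part of the upstream docstring): with `kernel=None` the filter
+    defaults to `[1] * factor`, so for `factor=2` this reduces to 2x2 average pooling:
+
+        x = torch.randn(1, 4, 8, 8)
+        y = downsample_2d(x, factor=2)  # shape (1, 4, 4, 4), same as F.avg_pool2d(x, 2)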
+ + Returns: + output: Tensor of the shape `[N, C, H // factor, W // factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + kernel = kernel * gain + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, kernel, down=factor, pad=((pad_value + 1) // 2, pad_value // 2) + ) + return output + + +def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): + """Upsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is + a: multiple of the upsampling factor. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + output: Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + kernel = kernel * (gain * factor ** 2) + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel, + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + return output + + +class ResnetBlock2D(nn.Module): + """ + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or + "ada_group" for a stronger conditioning with scale and shift. + kernal (`torch.FloatTensor`, optional, default to None): FIR filter, see + [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. 
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. + """ + + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-06, + non_linearity="swish", + time_embedding_norm="default", + kernel=None, + output_scale_factor=1.0, + use_in_shortcut=None, + up=False, + down=False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + if groups_out is None: + groups_out = groups + if self.time_embedding_norm == "ada_group": + self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) + else: + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) + elif self.time_embedding_norm == "ada_group": + self.time_emb_proj = None + else: + raise ValueError( + f"unknown time_embedding_norm : {self.time_embedding_norm} " + ) + else: + self.time_emb_proj = None + if self.time_embedding_norm == "ada_group": + self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) + else: + self.norm2 = torch.nn.GroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + self.dropout = torch.nn.Dropout(dropout) + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = torch.nn.Conv2d( + out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1 + ) + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + elif non_linearity == "gelu": + self.nonlinearity = nn.GELU() + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = 1, 3, 3, 1 + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = 1, 3, 3, 1 + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D( + in_channels, use_conv=False, padding=1, name="op" + ) + self.use_in_shortcut = ( + self.in_channels != conv_2d_out_channels + if 
use_in_shortcut is None + else use_in_shortcut + ) + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + if self.time_embedding_norm == "ada_group": + hidden_states = self.norm1(hidden_states, temb) + else: + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + if self.upsample is not None: + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + hidden_states = self.conv1(hidden_states) + if self.time_emb_proj is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + if self.time_embedding_norm == "ada_group": + hidden_states = self.norm2(hidden_states, temb) + else: + hidden_states = self.norm2(hidden_states) + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) + + +def rearrange_dims(tensor): + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + elif len(tensor.shape) == 4: + return tensor[:, :, 0, :] + else: + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + + +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + self.conv1d = nn.Conv1d( + inp_channels, out_channels, kernel_size, padding=kernel_size // 2 + ) + self.group_norm = nn.GroupNorm(n_groups, out_channels) + self.mish = nn.Mish() + + def forward(self, x): + x = self.conv1d(x) + x = rearrange_dims(x) + x = self.group_norm(x) + x = rearrange_dims(x) + x = self.mish(x) + return x + + +class ResidualTemporalBlock1D(nn.Module): + def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): + super().__init__() + self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) + self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) + self.time_emb_act = nn.Mish() + self.time_emb = nn.Linear(embed_dim, out_channels) + self.residual_conv = ( + nn.Conv1d(inp_channels, out_channels, 1) + if inp_channels != out_channels + else nn.Identity() + ) + + def forward(self, x, t): + """ + Args: + x : [ batch_size x inp_channels x horizon ] + t : [ batch_size x embed_dim ] + + returns: + out : [ batch_size x out_channels x horizon ] + """ + t = self.time_emb_act(t) + t = self.time_emb(t) + out = self.conv_in(x) + rearrange_dims(t) + out = 
self.conv_out(out) + return out + self.residual_conv(x) + + +@dataclass +class UNet1DOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): + Hidden states output. Output of last layer of model. + """ + + sample: "torch.FloatTensor" + + +class LinearMultiDim(nn.Linear): + def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): + in_features = ( + [in_features, second_dim, 1] + if isinstance(in_features, int) + else list(in_features) + ) + if out_features is None: + out_features = in_features + out_features = ( + [out_features, second_dim, 1] + if isinstance(out_features, int) + else list(out_features) + ) + self.in_features_multidim = in_features + self.out_features_multidim = out_features + super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) + + def forward(self, input_tensor, *args, **kwargs): + shape = input_tensor.shape + n_dim = len(self.in_features_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) + output_tensor = super().forward(input_tensor) + output_tensor = output_tensor.view( + *shape[0:-n_dim], *self.out_features_multidim + ) + return output_tensor + + +class ResnetBlockFlat(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-06, + time_embedding_norm="default", + use_in_shortcut=None, + second_dim=4, + **kwargs, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + in_channels = ( + [in_channels, second_dim, 1] + if isinstance(in_channels, int) + else list(in_channels) + ) + self.in_channels_prod = np.array(in_channels).prod() + self.channels_multidim = in_channels + if out_channels is not None: + out_channels = ( + [out_channels, second_dim, 1] + if isinstance(out_channels, int) + else list(out_channels) + ) + out_channels_prod = np.array(out_channels).prod() + self.out_channels_multidim = out_channels + else: + out_channels_prod = self.in_channels_prod + self.out_channels_multidim = self.channels_multidim + self.time_embedding_norm = time_embedding_norm + if groups_out is None: + groups_out = groups + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True + ) + self.conv1 = torch.nn.Conv2d( + self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0 + ) + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) + else: + self.time_emb_proj = None + self.norm2 = torch.nn.GroupNorm( + num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True + ) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels_prod, out_channels_prod, kernel_size=1, padding=0 + ) + self.nonlinearity = nn.SiLU() + self.use_in_shortcut = ( + self.in_channels_prod != out_channels_prod + if use_in_shortcut is None + else use_in_shortcut + ) + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + self.in_channels_prod, + out_channels_prod, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, input_tensor, temb): + shape = input_tensor.shape + n_dim = len(self.channels_multidim) + input_tensor = input_tensor.reshape( + *shape[0:-n_dim], self.in_channels_prod, 1, 1 + ) + input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) + hidden_states = input_tensor + hidden_states = 
self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + output_tensor = input_tensor + hidden_states + output_tensor = output_tensor.view(*shape[0:-n_dim], -1) + output_tensor = output_tensor.view( + *shape[0:-n_dim], *self.out_channels_multidim + ) + return output_tensor + + +class CrossAttnDownBlockFlat(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + output_states = () + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + 
hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +class DownBlockFlat(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +# def get_down_block( +# down_block_type, +# num_layers, +# in_channels, +# out_channels, +# temb_channels, +# add_downsample, +# resnet_eps, +# resnet_act_fn, +# attn_num_head_channels, +# resnet_groups=None, +# cross_attention_dim=None, +# downsample_padding=None, +# dual_cross_attention=False, +# use_linear_projection=False, +# only_cross_attention=False, +# upcast_attention=False, +# resnet_time_scale_shift="default", +# ): +# down_block_type = ( +# down_block_type[7:] +# if down_block_type.startswith("UNetRes") +# else down_block_type +# ) +# if down_block_type == "DownBlockFlat": +# return DownBlockFlat( +# num_layers=num_layers, +# in_channels=in_channels, +# out_channels=out_channels, +# temb_channels=temb_channels, +# add_downsample=add_downsample, +# resnet_eps=resnet_eps, +# resnet_act_fn=resnet_act_fn, +# resnet_groups=resnet_groups, +# downsample_padding=downsample_padding, +# resnet_time_scale_shift=resnet_time_scale_shift, +# ) +# elif down_block_type == "CrossAttnDownBlockFlat": +# if cross_attention_dim is None: +# raise 
ValueError( +# "cross_attention_dim must be specified for CrossAttnDownBlockFlat" +# ) +# return CrossAttnDownBlockFlat( +# num_layers=num_layers, +# in_channels=in_channels, +# out_channels=out_channels, +# temb_channels=temb_channels, +# add_downsample=add_downsample, +# resnet_eps=resnet_eps, +# resnet_act_fn=resnet_act_fn, +# resnet_groups=resnet_groups, +# downsample_padding=downsample_padding, +# cross_attention_dim=cross_attention_dim, +# attn_num_head_channels=attn_num_head_channels, +# dual_cross_attention=dual_cross_attention, +# use_linear_projection=use_linear_projection, +# only_cross_attention=only_cross_attention, +# resnet_time_scale_shift=resnet_time_scale_shift, +# ) +# raise ValueError(f"{down_block_type} is not supported.") + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + down_block_type = ( + down_block_type[7:] + if down_block_type.startswith("UNetRes") + else down_block_type + ) + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownBlock2D": + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock2D" + ) + return CrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D" + ) + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + 
in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +class MidResTemporalBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dim, + num_layers: "int" = 1, + add_downsample: "bool" = False, + add_upsample: "bool" = False, + non_linearity=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_downsample = add_downsample + resnets = [ + ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim) + ] + for _ in range(num_layers): + resnets.append( + ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim) + ) + self.resnets = nn.ModuleList(resnets) + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + 
self.nonlinearity = None + self.upsample = None + if add_upsample: + self.upsample = Downsample1D(out_channels, use_conv=True) + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True) + if self.upsample and self.downsample: + raise ValueError("Block cannot downsample and upsample") + + def forward(self, hidden_states, temb): + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + if self.upsample: + hidden_states = self.upsample(hidden_states) + if self.downsample: + self.downsample = self.downsample(hidden_states) + return hidden_states + + +_kernels = { + "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], + "cubic": [ + -0.01171875, + -0.03515625, + 0.11328125, + 0.43359375, + 0.43359375, + 0.11328125, + -0.03515625, + -0.01171875, + ], + "lanczos3": [ + 0.003689131001010537, + 0.015056144446134567, + -0.03399861603975296, + -0.066637322306633, + 0.13550527393817902, + 0.44638532400131226, + 0.44638532400131226, + 0.13550527393817902, + -0.066637322306633, + -0.03399861603975296, + 0.015056144446134567, + 0.003689131001010537, + ], +} + + +class Downsample1d(nn.Module): + def __init__(self, kernel="linear", pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states): + hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) + weight = hidden_states.new_zeros( + [hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]] + ) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + weight[indices, indices] = self.kernel + return F.conv1d(hidden_states, weight, stride=2) + + +class ResConvBlock(nn.Module): + def __init__(self, in_channels, mid_channels, out_channels, is_last=False): + super().__init__() + self.is_last = is_last + self.has_conv_skip = in_channels != out_channels + if self.has_conv_skip: + self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) + self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) + self.group_norm_1 = nn.GroupNorm(1, mid_channels) + self.gelu_1 = nn.GELU() + self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) + if not self.is_last: + self.group_norm_2 = nn.GroupNorm(1, out_channels) + self.gelu_2 = nn.GELU() + + def forward(self, hidden_states): + residual = ( + self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states + ) + hidden_states = self.conv_1(hidden_states) + hidden_states = self.group_norm_1(hidden_states) + hidden_states = self.gelu_1(hidden_states) + hidden_states = self.conv_2(hidden_states) + if not self.is_last: + hidden_states = self.group_norm_2(hidden_states) + hidden_states = self.gelu_2(hidden_states) + output = hidden_states + residual + return output + + +class SelfAttention1d(nn.Module): + def __init__(self, in_channels, n_head=1, dropout_rate=0.0): + super().__init__() + self.channels = in_channels + self.group_norm = nn.GroupNorm(1, num_channels=in_channels) + self.num_heads = n_head + self.query = nn.Linear(self.channels, self.channels) + self.key = nn.Linear(self.channels, self.channels) + self.value = nn.Linear(self.channels, self.channels) + self.proj_attn = nn.Linear(self.channels, self.channels, 1) + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + def transpose_for_scores(self, projection: "torch.Tensor") -> torch.Tensor: + 
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel_dim, seq = hidden_states.shape + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) + attention_scores = torch.matmul( + query_states * scale, key_states.transpose(-1, -2) * scale + ) + attention_probs = torch.softmax(attention_scores, dim=-1) + hidden_states = torch.matmul(attention_probs, value_states) + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.dropout(hidden_states) + output = hidden_states + residual + return output + + +class Upsample1d(nn.Module): + def __init__(self, kernel="linear", pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states, temb=None): + hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = hidden_states.new_zeros( + [hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]] + ) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + weight[indices, indices] = self.kernel + return F.conv_transpose1d( + hidden_states, weight, stride=2, padding=self.pad * 2 + 1 + ) + + +class UNetMidBlock1D(nn.Module): + def __init__(self, mid_channels, in_channels, out_channels=None): + super().__init__() + out_channels = in_channels if out_channels is None else out_channels + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.up = Upsample1d(kernel="cubic") + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + for attn, resnet in zip(self.attentions, self.resnets): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + hidden_states = self.up(hidden_states) + return hidden_states + + +class ValueFunctionMidBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, embed_dim): + super().__init__() + 
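+        # Two ResidualTemporalBlock1D stages, each followed by a strided-conv
+        # Downsample1D; channels shrink from in_channels to in_channels // 4.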
self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.res1 = ResidualTemporalBlock1D( + in_channels, in_channels // 2, embed_dim=embed_dim + ) + self.down1 = Downsample1D(out_channels // 2, use_conv=True) + self.res2 = ResidualTemporalBlock1D( + in_channels // 2, in_channels // 4, embed_dim=embed_dim + ) + self.down2 = Downsample1D(out_channels // 4, use_conv=True) + + def forward(self, x, temb=None): + x = self.res1(x, temb) + x = self.down1(x) + x = self.res2(x, temb) + x = self.down2(x) + return x + + +def get_mid_block( + mid_block_type, + num_layers, + in_channels, + mid_channels, + out_channels, + embed_dim, + add_downsample, +): + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) + elif mid_block_type == "ValueFunctionMidBlock1D": + return ValueFunctionMidBlock1D( + in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim + ) + elif mid_block_type == "UNetMidBlock1D": + return UNetMidBlock1D( + in_channels=in_channels, + mid_channels=mid_channels, + out_channels=out_channels, + ) + raise ValueError(f"{mid_block_type} does not exist.") + + +class OutConv1DBlock(nn.Module): + def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): + super().__init__() + self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) + if act_fn == "silu": + self.final_conv1d_act = nn.SiLU() + if act_fn == "mish": + self.final_conv1d_act = nn.Mish() + self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) + + def forward(self, hidden_states, temb=None): + hidden_states = self.final_conv1d_1(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_gn(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_act(hidden_states) + hidden_states = self.final_conv1d_2(hidden_states) + return hidden_states + + +class OutValueFunctionBlock(nn.Module): + def __init__(self, fc_dim, embed_dim): + super().__init__() + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + embed_dim, fc_dim // 2), + nn.Mish(), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward(self, hidden_states, temb): + hidden_states = hidden_states.view(hidden_states.shape[0], -1) + hidden_states = torch.cat((hidden_states, temb), dim=-1) + for layer in self.final_block: + hidden_states = layer(hidden_states) + return hidden_states + + +def get_out_block( + *, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim +): + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim) + return None + + +class CrossAttnUpBlockFlat(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + prev_output_channel: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + 
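+        # upcast_attention: if True, attention is computed in float32
+        # (mirrors the diffusers flag of the same name).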
upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + cross_attention_kwargs=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states + + +class UpBlockFlat(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + res_skip_channels = in_channels if i == 
num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None + ): + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states + + +# def get_up_block( +# up_block_type, +# num_layers, +# in_channels, +# out_channels, +# prev_output_channel, +# temb_channels, +# add_upsample, +# resnet_eps, +# resnet_act_fn, +# attn_num_head_channels, +# resnet_groups=None, +# cross_attention_dim=None, +# dual_cross_attention=False, +# use_linear_projection=False, +# only_cross_attention=False, +# upcast_attention=False, +# resnet_time_scale_shift="default", +# ): +# up_block_type = ( +# up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type +# ) +# if up_block_type == "UpBlockFlat": +# return UpBlockFlat( +# num_layers=num_layers, +# in_channels=in_channels, +# out_channels=out_channels, +# prev_output_channel=prev_output_channel, +# temb_channels=temb_channels, +# add_upsample=add_upsample, +# resnet_eps=resnet_eps, +# resnet_act_fn=resnet_act_fn, +# resnet_groups=resnet_groups, +# resnet_time_scale_shift=resnet_time_scale_shift, +# ) +# elif up_block_type == "CrossAttnUpBlockFlat": +# if cross_attention_dim is None: +# raise ValueError( +# "cross_attention_dim must be specified for CrossAttnUpBlockFlat" +# ) +# return CrossAttnUpBlockFlat( +# num_layers=num_layers, +# in_channels=in_channels, +# out_channels=out_channels, +# prev_output_channel=prev_output_channel, +# temb_channels=temb_channels, +# add_upsample=add_upsample, +# resnet_eps=resnet_eps, +# resnet_act_fn=resnet_act_fn, +# resnet_groups=resnet_groups, +# cross_attention_dim=cross_attention_dim, +# attn_num_head_channels=attn_num_head_channels, +# dual_cross_attention=dual_cross_attention, +# use_linear_projection=use_linear_projection, +# only_cross_attention=only_cross_attention, +# resnet_time_scale_shift=resnet_time_scale_shift, +# ) +# raise ValueError(f"{up_block_type} is not supported.") +# + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + 
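+    # dual_cross_attention swaps Transformer2DModel for DualTransformer2DModel in
+    # the cross-attention up blocks constructed below.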
dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + up_block_type = ( + up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + ) + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock2D" + ) + return CrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D" + ) + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnUpBlock2D": + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + 
resnet_act_fn=resnet_act_fn, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNet1DModel(ModelMixin, ConfigMixin): + """ + UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. + in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. + time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. + freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to : + obj:`False`): Whether to flip sin to cos for fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(32, 32, 64)`): Tuple of block output channels. + mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet. + act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks. + layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. + downsample_each_block (`int`, *optional*, defaults to False: + experimental feature for using a UNet without upsampling. 
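+
+    Example (a rough sketch of the intended call pattern; `extra_in_channels=16` is
+    an assumption so that the 16-channel Fourier time embedding concatenated by
+    `DownBlock1DNoSkip` matches the width of the first down block):
+
+        import torch
+
+        model = UNet1DModel(sample_size=2048, in_channels=2, out_channels=2, extra_in_channels=16)
+        noise = torch.randn(1, 2, 2048)
+        denoised = model(noise, timestep=10).sample  # same shape as `noise`
+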
+ """ + + @register_to_config + def __init__( + self, + sample_size: "int" = 65536, + sample_rate: "Optional[int]" = None, + in_channels: "int" = 2, + out_channels: "int" = 2, + extra_in_channels: "int" = 0, + time_embedding_type: "str" = "fourier", + flip_sin_to_cos: "bool" = True, + use_timestep_embedding: "bool" = False, + freq_shift: "float" = 0.0, + down_block_types: "Tuple[str]" = ( + "DownBlock1DNoSkip", + "DownBlock1D", + "AttnDownBlock1D", + ), + up_block_types: "Tuple[str]" = ( + "AttnUpBlock1D", + "UpBlock1D", + "UpBlock1DNoSkip", + ), + mid_block_type: "Tuple[str]" = "UNetMidBlock1D", + out_block_type: "str" = None, + block_out_channels: "Tuple[int]" = (32, 32, 64), + act_fn: "str" = None, + norm_num_groups: "int" = 8, + layers_per_block: "int" = 1, + downsample_each_block: "bool" = False, + ): + super().__init__() + self.sample_size = sample_size + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection( + embedding_size=8, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps( + block_out_channels[0], + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=freq_shift, + ) + timestep_input_dim = block_out_channels[0] + if use_timestep_embedding: + time_embed_dim = block_out_channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=timestep_input_dim, + time_embed_dim=time_embed_dim, + act_fn=act_fn, + out_dim=block_out_channels[0], + ) + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.out_block = None + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + if i == 0: + input_channel += extra_in_channels + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=not is_final_block or downsample_each_block, + ) + self.down_blocks.append(down_block) + self.mid_block = get_mid_block( + mid_block_type, + in_channels=block_out_channels[-1], + mid_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + embed_dim=block_out_channels[0], + num_layers=layers_per_block, + add_downsample=downsample_each_block, + ) + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + if out_block_type is None: + final_upsample_channels = out_channels + else: + final_upsample_channels = block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = ( + reversed_block_out_channels[i + 1] + if i < len(up_block_types) - 1 + else final_upsample_channels + ) + is_final_block = i == len(block_out_channels) - 1 + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block, + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + num_groups_out = ( + norm_num_groups + if norm_num_groups is not None + else min(block_out_channels[0] // 4, 32) + ) + self.out_block = get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + 
embed_dim=block_out_channels[0], + out_channels=out_channels, + act_fn=act_fn, + fc_dim=block_out_channels[-1] // 4, + ) + + def forward( + self, + sample: "torch.FloatTensor", + timestep: "Union[torch.Tensor, float, int]", + return_dict: "bool" = True, + ) -> Union[UNet1DOutput, Tuple]: + """ + Args: + sample (`torch.FloatTensor`): `(batch_size, num_channels, sample_size)` noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor( + [timesteps], dtype=torch.long, device=sample.device + ) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None] + timestep_embed = self.time_proj(timesteps) + if self.config.use_timestep_embedding: + timestep_embed = self.time_mlp(timestep_embed) + else: + timestep_embed = timestep_embed[..., None] + timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]) + timestep_embed = timestep_embed.broadcast_to( + sample.shape[:1] + timestep_embed.shape[1:] + ) + down_block_res_samples = () + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block( + hidden_states=sample, temb=timestep_embed + ) + down_block_res_samples += res_samples + if self.mid_block: + sample = self.mid_block(sample, timestep_embed) + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-1:] + down_block_res_samples = down_block_res_samples[:-1] + sample = upsample_block( + sample, res_hidden_states_tuple=res_samples, temb=timestep_embed + ) + if self.out_block: + sample = self.out_block(sample, timestep_embed) + if not return_dict: + return (sample,) + return UNet1DOutput(sample=sample) + + +class DownResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + conv_shortcut=False, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + if groups_out is None: + groups_out = groups + resnets = [ + ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels) + ] + for _ in range(num_layers): + resnets.append( + ResidualTemporalBlock1D( + out_channels, out_channels, embed_dim=temb_channels + ) + ) + self.resnets = nn.ModuleList(resnets) + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def forward(self, hidden_states, temb=None): + output_states = () + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in 
self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + return hidden_states, output_states + + +class UpResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.add_upsample = add_upsample + self.output_scale_factor = output_scale_factor + if groups_out is None: + groups_out = groups + resnets = [ + ResidualTemporalBlock1D( + 2 * in_channels, out_channels, embed_dim=temb_channels + ) + ] + for _ in range(num_layers): + resnets.append( + ResidualTemporalBlock1D( + out_channels, out_channels, embed_dim=temb_channels + ) + ) + self.resnets = nn.ModuleList(resnets) + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv_transpose=True) + + def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): + if res_hidden_states_tuple is not None: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + return hidden_states + + +class AttnDownBlock1D(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + return hidden_states, (hidden_states,) + + +class DownBlock1D(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + self.resnets = nn.ModuleList(resnets) + + def 
forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + return hidden_states, (hidden_states,) + + +class DownBlock1DNoSkip(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = torch.cat([hidden_states, temb], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + return hidden_states, (hidden_states,) + + +class AttnUpBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + hidden_states = self.up(hidden_states) + return hidden_states + + +class UpBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + hidden_states = self.up(hidden_states) + return hidden_states + + +class UpBlock1DNoSkip(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), + ] + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + return hidden_states + + +@dataclass +class UNet2DOutput(BaseOutput): + """ + Args: + sample 
(`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states output. Output of last layer of model. + """ + + sample: "torch.FloatTensor" + + +class UNetMidBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + add_attention: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.add_attention = add_attention + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + for _ in range(num_layers): + if self.add_attention: + attentions.append( + AttentionBlock( + in_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) + ) + else: + attentions.append(None) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +class UNet2DModel(ModelMixin, ConfigMixin): + """ + UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to : + obj:`True`): Whether to flip sin to cos for fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block + types. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): + The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. 
+ up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(224, 448, 672, 896)`): Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. + mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. + downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. + norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization. + norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + """ + + @register_to_config + def __init__( + self, + sample_size: "Optional[Union[int, Tuple[int, int]]]" = None, + in_channels: "int" = 3, + out_channels: "int" = 3, + center_input_sample: "bool" = False, + time_embedding_type: "str" = "positional", + freq_shift: "int" = 0, + flip_sin_to_cos: "bool" = True, + down_block_types: "Tuple[str]" = ( + "DownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + ), + up_block_types: "Tuple[str]" = ( + "AttnUpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + ), + block_out_channels: "Tuple[int]" = (224, 448, 672, 896), + layers_per_block: "int" = 2, + mid_block_scale_factor: "float" = 1, + downsample_padding: "int" = 1, + act_fn: "str" = "silu", + attention_head_dim: "Optional[int]" = 8, + norm_num_groups: "int" = 32, + norm_eps: "float" = 1e-05, + resnet_time_scale_shift: "str" = "default", + add_attention: "bool" = True, + class_embed_type: "Optional[str]" = None, + num_class_embeds: "Optional[int]" = None, + ): + super().__init__() + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
+ ) + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1) + ) + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection( + embedding_size=block_out_channels[0], scale=16 + ) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift + ) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + add_attention=add_attention, + ) + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + is_final_block = i == len(block_out_channels) - 1 + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + num_groups_out = ( + norm_num_groups + if norm_num_groups is not None + else min(block_out_channels[0] // 4, 32) + ) + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps + ) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=3, padding=1 + ) + + def forward( + self, + sample: "torch.FloatTensor", + timestep: "Union[torch.Tensor, float, int]", + class_labels: "Optional[torch.Tensor]" = None, + return_dict: "bool" = True, + ) 
-> Union[UNet2DOutput, Tuple]: + """ + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor( + [timesteps], dtype=torch.long, device=sample.device + ) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None] + timesteps = timesteps * torch.ones( + sample.shape[0], dtype=timesteps.dtype, device=timesteps.device + ) + t_emb = self.time_proj(timesteps) + t_emb = t_emb + emb = self.time_embedding(t_emb) + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when doing class conditioning" + ) + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + class_emb = self.class_embedding(class_labels) + emb = emb + class_emb + skip_sample = sample + sample = self.conv_in(sample) + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "skip_conv"): + sample, res_samples, skip_sample = downsample_block( + hidden_states=sample, temb=emb, skip_sample=skip_sample + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + down_block_res_samples += res_samples + sample = self.mid_block(sample, emb) + skip_sample = None + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + if hasattr(upsample_block, "skip_conv"): + sample, skip_sample = upsample_block( + sample, res_samples, emb, skip_sample + ) + else: + sample = upsample_block(sample, res_samples, emb) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + if skip_sample is not None: + sample += skip_sample + if self.config.time_embedding_type == "fourier": + timesteps = timesteps.reshape( + (sample.shape[0], *([1] * len(sample.shape[1:]))) + ) + sample = sample / timesteps + if not return_dict: + return (sample,) + return UNet2DOutput(sample=sample) + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) 
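+        # Layout: one ResnetBlock2D, then `num_layers` pairs of
+        # (Transformer2DModel or DualTransformer2DModel, ResnetBlock2D).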
+ resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + ): + super().__init__() + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.num_heads = in_channels // self.attn_num_head_channels + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + for _ in range(num_layers): + attentions.append( + CrossAttention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, 
+ output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +class AttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None): + output_states = () + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + 
in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + output_states = () + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + 
out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + attentions = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + 
out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=np.sqrt(2.0), + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList( + [FirDownsample2D(out_channels, out_channels=out_channels)] + ) + self.skip_conv = nn.Conv2d( + 3, out_channels, kernel_size=(1, 1), stride=(1, 1) + ) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + hidden_states = self.skip_conv(skip_sample) + hidden_states + output_states += (hidden_states,) + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_pre_norm: "bool" = True, + output_scale_factor=np.sqrt(2.0), + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + 
ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList( + [FirDownsample2D(out_channels, out_channels=out_channels)] + ) + self.skip_conv = nn.Conv2d( + 3, out_channels, kernel_size=(1, 1), stride=(1, 1) + ) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + hidden_states = self.skip_conv(skip_sample) + hidden_states + output_states += (hidden_states,) + return hidden_states, output_states, skip_sample + + +class ResnetDownsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + 
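+                # Unlike the blocks above that downsample with Downsample2D, the
+                # downsamplers here are ResnetBlock2D instances constructed with
+                # down=True, so they are conditioned on the time embedding and are
+                # called with temb as well.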
hidden_states = downsampler(hidden_states, temb) + output_states += (hidden_states,) + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + self.has_cross_attention = True + resnets = [] + attentions = [] + self.attn_num_head_channels = attn_num_head_channels + self.num_heads = out_channels // self.attn_num_head_channels + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + CrossAttention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + output_states = () + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + output_states += (hidden_states,) + return hidden_states, output_states + + +class KDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 4, + resnet_eps: "float" = 1e-05, + resnet_act_fn: "str" = "gelu", + resnet_group_size: "int" = 32, + add_downsample=False, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + 
groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + return hidden_states, output_states + + +class KAttentionBlock(nn.Module): + """ + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
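+        upcast_attention (`bool`, *optional*, defaults to `False`):
+            Forwarded to the cross-attention layer's `upcast_attention` flag.
+        temb_channels (`int`, *optional*, defaults to 768):
+            The number of channels in the timestep embedding consumed by the `AdaGroupNorm` layers.
+        add_self_attention (`bool`, *optional*, defaults to `False`):
+            Whether to prepend a self-attention block (`norm1`/`attn1`) before the cross-attention block.
+        cross_attention_norm (`bool`, *optional*, defaults to `False`):
+            Forwarded to the cross-attention layer's `cross_attention_norm` argument.
+        group_size (`int`, *optional*, defaults to 32):
+            Used to derive the number of groups for `AdaGroupNorm` as `max(1, dim // group_size)`.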
+ """ + + def __init__( + self, + dim: "int", + num_attention_heads: "int", + attention_head_dim: "int", + dropout: "float" = 0.0, + cross_attention_dim: "Optional[int]" = None, + attention_bias: "bool" = False, + upcast_attention: "bool" = False, + temb_channels: "int" = 768, + add_self_attention: "bool" = False, + cross_attention_norm: "bool" = False, + group_size: "int" = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attention_norm=cross_attention_norm, + ) + + def _to_3d(self, hidden_states, height, weight): + return hidden_states.permute(0, 2, 3, 1).reshape( + hidden_states.shape[0], height * weight, -1 + ) + + def _to_4d(self, hidden_states, height, weight): + return hidden_states.permute(0, 2, 1).reshape( + hidden_states.shape[0], -1, height, weight + ) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + emb=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + if self.add_self_attention: + norm_hidden_states = self.norm1(hidden_states, emb) + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn1( + norm_hidden_states, encoder_hidden_states=None, **cross_attention_kwargs + ) + attn_output = self._to_4d(attn_output, height, weight) + hidden_states = attn_output + hidden_states + norm_hidden_states = self.norm2(hidden_states, emb) + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + hidden_states = attn_output + hidden_states + return hidden_states + + +class KCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + cross_attention_dim: "int", + dropout: "float" = 0.0, + num_layers: "int" = 4, + resnet_group_size: "int" = 32, + add_downsample=True, + attn_num_head_channels: "int" = 64, + add_self_attention: "bool" = False, + resnet_eps: "float" = 1e-05, + resnet_act_fn: "str" = "gelu", + ): + super().__init__() + resnets = [] + attentions = [] + self.has_cross_attention = True + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + 
out_channels, + out_channels // attn_num_head_channels, + attn_num_head_channels, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm=True, + group_size=resnet_group_size, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + output_states = () + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + if self.downsamplers is None: + output_states += (None,) + else: + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + return hidden_states, output_states + + +class AttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet, attn in zip(self.resnets, self.attentions): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple 
= res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + prev_output_channel: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + cross_attention_kwargs=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + 
)[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None + ): + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = 
resnet(hidden_states, temb=None) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=np.sqrt(2.0), + upsample_padding=1, + add_upsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + AttentionBlock( + out_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + ) + ) + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + 
time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d( + out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), + num_channels=out_channels, + eps=resnet_eps, + affine=True, + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None + ): + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = self.attentions[0](hidden_states) + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + skip_sample = skip_sample + skip_sample_states + hidden_states = self.resnet_up(hidden_states, temb) + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_pre_norm: "bool" = True, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + upsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d( + out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), + num_channels=out_channels, + eps=resnet_eps, + affine=True, + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None + ): + for resnet in 
self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + skip_sample = skip_sample + skip_sample_states + hidden_states = self.resnet_up(hidden_states, temb) + return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + prev_output_channel: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None + ): + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + prev_output_channel: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + 
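+        # Like SimpleCrossAttnDownBlock2D above, this block uses plain CrossAttention
+        # layers with added key/value projections (CrossAttnAddedKVProcessor) instead
+        # of Transformer2DModel wrappers.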
self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + self.num_heads = out_channels // self.attn_num_head_channels + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + CrossAttention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + for resnet, attn in zip(self.resnets, self.attentions): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + return hidden_states + + +class KUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 5, + resnet_eps: "float" = 1e-05, + resnet_act_fn: "str" = "gelu", + resnet_group_size: "Optional[int]" = 32, + add_upsample=True, + ): + super().__init__() + resnets = [] + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=k_out_channels + if i == num_layers - 1 + else out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = 
None + self.gradient_checkpointing = False + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None + ): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +class KCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: "int", + out_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 4, + resnet_eps: "float" = 1e-05, + resnet_act_fn: "str" = "gelu", + resnet_group_size: "int" = 32, + attn_num_head_channels=1, + cross_attention_dim: "int" = 768, + add_upsample: "bool" = True, + upcast_attention: "bool" = False, + ): + super().__init__() + resnets = [] + attentions = [] + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + if is_middle_block and i == num_layers - 1: + conv_2d_out_channels = k_out_channels + else: + conv_2d_out_channels = None + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + k_out_channels if i == num_layers - 1 else out_channels, + k_out_channels // attn_num_head_channels + if i == num_layers - 1 + else out_channels // attn_num_head_channels, + attn_num_head_channels, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm=True, + upcast_attention=upcast_attention, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + cross_attention_kwargs=None, + upsample_size=None, + attention_mask=None, + ): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): 
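+                    # Gradient-checkpointing helper: wraps `module` so
+                    # torch.utils.checkpoint can re-run its forward during the backward
+                    # pass instead of storing activations; `return_dict` is only
+                    # forwarded to the wrapped call when it is not None.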
+ def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +AttnProcessor = Union[ + CrossAttnProcessor, + XFormersCrossAttnProcessor, + SlicedAttnProcessor, + CrossAttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + LoRACrossAttnProcessor, + LoRAXFormersCrossAttnProcessor, +] + +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" + +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" + + +class UNet2DConditionLoadersMixin: + def load_attn_procs( + self, + pretrained_model_name_or_path_or_dict: "Union[str, Dict[str, torch.Tensor]]", + **kwargs, + ): + """ + Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be + defined in + [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) + and be a `torch.nn.Module` class. + + + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + if ( + is_safetensors_available() + and weight_name is None + or weight_name is not None + and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except EnvironmentError: + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + attn_processors = {} + is_lora = all("lora" in k for k in state_dict.keys()) + if is_lora: + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join( + key.split(".")[-3:] + ) + lora_grouped_dict[attn_processor_key][sub_key] = value + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + attn_processors[key] = LoRACrossAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=rank, + ) + 
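+                # The LoRA rank, hidden size, and cross-attention dim are read off the
+                # shapes of the grouped `to_k_lora` weights above; the grouped sub-dict
+                # of LoRA up/down projection weights is then loaded into the new
+                # processor.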
attn_processors[key].load_state_dict(value_dict) + else: + raise ValueError( + f"{model_file} does not seem to be in the correct format expected by LoRA training." + ) + attn_processors = {k: v for k, v in attn_processors.items()} + self.set_attn_processor(attn_processors) + + def save_attn_procs( + self, + save_directory: "Union[str, os.PathLike]", + is_main_process: "bool" = True, + weight_name: "str" = None, + save_function: "Callable" = None, + safe_serialization: "bool" = False, + **kwargs, + ): + """ + Save an attention processor to a directory, so that it can be re-loaded using the + `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + """ + weight_name = weight_name or deprecate( + "weights_name", + "0.18.0", + "`weights_name` is deprecated, please use `weight_name` instead.", + take_from=kwargs, + ) + if os.path.isfile(save_directory): + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + return + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file( + weights, filename, metadata={"format": "pt"} + ) + + else: + save_function = torch.save + os.makedirs(save_directory, exist_ok=True) + model_to_save = AttnProcsLayers(self.attn_processors) + state_dict = model_to_save.state_dict() + if weight_name is None: + if safe_serialization: + weight_name = LORA_WEIGHT_NAME_SAFE + else: + weight_name = LORA_WEIGHT_NAME + save_function(state_dict, os.path.join(save_directory, weight_name)) + logger.info( + f"Model weights saved in {os.path.join(save_directory, weight_name)}" + ) + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: "torch.FloatTensor" + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + """ + UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the + mid block layer if `None`. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, or `"projection"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, default to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + timestep_post_act (`str, *optional*, default to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, default to `None`): + The dimension of `cond_proj` layer in timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. 
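+
+    Example (an illustrative sketch added for orientation, not part of the upstream
+    docstring): with the default configuration documented above, the three downsampling
+    blocks reduce the spatial size by a factor of 8, so height and width are kept at
+    multiples of 8 here to avoid the upsample-size forcing in `forward`. The sequence
+    length (77) below is an arbitrary choice; the last dimension must equal
+    `cross_attention_dim` (1280 by default).
+
+        >>> import torch
+        >>> from pi.models.unet import UNet2DConditionModel
+
+        >>> model = UNet2DConditionModel().eval()
+        >>> latents = torch.randn(1, 4, 64, 64)       # (batch, in_channels, height, width)
+        >>> text_states = torch.randn(1, 77, 1280)    # (batch, seq_len, cross_attention_dim)
+        >>> with torch.no_grad():
+        ...     out = model(latents, 999, text_states)  # int timesteps are accepted too
+        >>> out.sample.shape
+        torch.Size([1, 4, 64, 64])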
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: "Optional[int]" = None, + in_channels: "int" = 4, + out_channels: "int" = 4, + center_input_sample: "bool" = False, + flip_sin_to_cos: "bool" = True, + freq_shift: "int" = 0, + down_block_types: "Tuple[str]" = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: "Optional[str]" = "UNetMidBlock2DCrossAttn", + up_block_types: "Tuple[str]" = ( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ), + only_cross_attention: "Union[bool, Tuple[bool]]" = False, + block_out_channels: "Tuple[int]" = (320, 640, 1280, 1280), + layers_per_block: "int" = 2, + downsample_padding: "int" = 1, + mid_block_scale_factor: "float" = 1, + act_fn: "str" = "silu", + norm_num_groups: "Optional[int]" = 32, + norm_eps: "float" = 1e-05, + cross_attention_dim: "int" = 1280, + attention_head_dim: "Union[int, Tuple[int]]" = 8, + dual_cross_attention: "bool" = False, + use_linear_projection: "bool" = False, + class_embed_type: "Optional[str]" = None, + num_class_embeds: "Optional[int]" = None, + upcast_attention: "bool" = False, + resnet_time_scale_shift: "str" = "default", + time_embedding_type: "str" = "positional", + timestep_post_act: "Optional[str]" = None, + time_cond_proj_dim: "Optional[int]" = None, + conv_in_kernel: "int" = 3, + conv_out_kernel: "int" = 3, + projection_class_embeddings_input_dim: "Optional[int]" = None, + ): + super().__init__() + self.sample_size = sample_size + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + if not isinstance(only_cross_attention, bool) and len( + only_cross_attention + ) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + if time_embedding_type == "fourier": + time_embed_dim = block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError( + f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." + ) + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift + ) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. 
Pleaes make sure to use one of `fourier` or `positional`." + ) + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) + else: + self.class_embedding = None + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + self.num_upsamplers = 0 + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel 
= reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + padding=conv_out_padding, + ) + + @property + def attn_processors(self) -> Dict[str, AttnProcessor]: + """ + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + processors = {} + + def fn_recursive_add_processors( + name: "str", + module: "torch.nn.Module", + processors: "Dict[str, AttnProcessor]", + ): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + return processors + + def set_attn_processor( + self, processor: "Union[AttnProcessor, Dict[str, AttnProcessor]]" + ): + """ + Parameters: + `processor (`dict` of `AttnProcessor` or `AttnProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `CrossAttention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: + + """ + count = len(self.attn_processors.keys()) + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor( + name: "str", module: "torch.nn.Module", processor + ): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_attention_slice(self, slice_size): + """ + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: "torch.nn.Module"): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + num_slicable_layers = len(sliceable_head_dims) + if slice_size == "auto": + slice_size = [(dim // 2) for dim in sliceable_head_dims] + elif slice_size == "max": + slice_size = num_slicable_layers * [1] + slice_size = ( + num_slicable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
+ ) + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + def fn_recursive_set_attention_slice( + module: "torch.nn.Module", slice_size: "List[int]" + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance( + module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D) + ): + module.gradient_checkpointing = value + + def forward( + self, + sample: "torch.FloatTensor", + timestep: "Union[torch.Tensor, float, int]", + encoder_hidden_states: "torch.Tensor", + class_labels: "Optional[torch.Tensor]" = None, + timestep_cond: "Optional[torch.Tensor]" = None, + attention_mask: "Optional[torch.Tensor]" = None, + cross_attention_kwargs: "Optional[Dict[str, Any]]" = None, + down_block_additional_residuals: "Optional[Tuple[torch.Tensor]]" = None, + mid_block_additional_residual: "Optional[torch.Tensor]" = None, + return_dict: "bool" = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + """ + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. 
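+
+        Editorial note (not part of the upstream docstring): `down_block_additional_residuals`
+        and `mid_block_additional_residual` are not described above; as the body below shows,
+        they are simply added element-wise to the corresponding down-block residuals and to the
+        mid-block output, which is how externally computed residuals (e.g. adapter/ControlNet
+        style conditioning) can be injected. A minimal call sketch with hypothetical tensor
+        names, using the tuple return that can be convenient when tracing:
+
+            >>> (pred,) = model(noisy_latents, t, text_states, return_dict=False)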
+ """ + default_overall_up_factor = 2 ** self.num_upsamplers + forward_upsample_size = False + upsample_size = None + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + if attention_mask is not None: + attention_mask = (1 - attention_mask) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + timesteps = timestep + if not torch.is_tensor(timesteps): + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None] + timesteps = timesteps.expand(sample.shape[0]) + t_emb = self.time_proj(timesteps) + t_emb = t_emb + emb = self.time_embedding(t_emb, timestep_cond) + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + class_emb = self.class_embedding(class_labels) + emb = emb + class_emb + sample = self.conv_in(sample) + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + down_block_res_samples += res_samples + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples += (down_block_res_sample,) + down_block_res_samples = new_down_block_res_samples + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + if self.conv_norm_out: + sample = 
self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + if not return_dict: + return (sample,) + return UNet2DConditionOutput(sample=sample) + + +def rename_key(key): + regex = "\\w+[.]\\d+" + pats = re.findall(regex, key) + for pat in pats: + key = key.replace(pat, "_".join(pat.split("."))) + return key + + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + double_z=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.conv_in = torch.nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1 + ) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-06, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=None, + temb_channels=None, + ) + self.down_blocks.append(down_block) + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-06, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=norm_num_groups, + temb_channels=None, + ) + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-06 + ) + self.conv_act = nn.SiLU() + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d( + block_out_channels[-1], conv_out_channels, 3, padding=1 + ) + + def forward(self, x): + sample = x + sample = self.conv_in(sample) + for down_block in self.down_blocks: + sample = down_block(sample) + sample = self.mid_block(sample) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + return sample + + +class Decoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + ): + super().__init__() + self.layers_per_block = layers_per_block + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1 + ) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-06, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=norm_num_groups, + temb_channels=None, + ) + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + 
resnet_eps=1e-06, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=None, + temb_channels=None, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-06 + ) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def forward(self, z): + sample = z + sample = self.conv_in(sample) + sample = self.mid_block(sample) + for up_block in self.up_blocks: + sample = up_block(sample) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + return sample + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + def __init__( + self, + n_e, + vq_embed_dim, + beta, + remap=None, + unknown_index="random", + sane_index_shape=False, + legacy=True, + ): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + self.legacy = legacy + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + None + else: + self.re_embed = n_e + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used + if self.re_embed > self.used.shape[0]: + inds[inds >= self.used.shape[0]] = 0 + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.vq_embed_dim) + min_encoding_indices = torch.argmin( + torch.cdist(z_flattened, self.embedding.weight), dim=1 + ) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( + (z_q - z.detach()) ** 2 + ) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean( + (z_q - z.detach()) ** 2 + ) + z_q = z + (z_q - z).detach() + z_q = z_q.permute(0, 3, 1, 2).contiguous() + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3] + ) + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, 
shape): + if self.remap is not None: + indices = indices.reshape(shape[0], -1) + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) + z_q = self.embedding(indices) + if shape is not None: + z_q = z_q.view(shape) + z_q = z_q.permute(0, 3, 1, 2).contiguous() + return z_q + + +@dataclass +class DecoderOutput(BaseOutput): + """ + Output of decoding method. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Decoded output sample of the model. Output of the last layer of the model. + """ + + sample: "torch.FloatTensor" + + +@dataclass +class VQEncoderOutput(BaseOutput): + """ + Output of VQModel encoding method. + + Args: + latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Encoded output sample of the model. Output of the last layer of the model. + """ + + latents: "torch.FloatTensor" + + +class VQModel(ModelMixin, ConfigMixin): + """VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray + Kavukcuoglu. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. + scaling_factor (`float`, *optional*, defaults to `0.18215`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. 
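+
+    Example (an illustrative sketch, not part of the upstream docstring): the default
+    configuration uses a single encoder/decoder block with no down- or upsampling, so the
+    latents and the reconstruction keep the input resolution, and `decode` quantizes against
+    the 256-entry codebook before decoding.
+
+        >>> import torch
+        >>> from pi.models.unet import VQModel
+
+        >>> vq = VQModel().eval()
+        >>> image = torch.randn(1, 3, 32, 32)
+        >>> with torch.no_grad():
+        ...     latents = vq.encode(image).latents    # (1, 3, 32, 32); 3 == latent_channels
+        ...     recon = vq.decode(latents).sample
+        >>> recon.shape
+        torch.Size([1, 3, 32, 32])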
+ """ + + @register_to_config + def __init__( + self, + in_channels: "int" = 3, + out_channels: "int" = 3, + down_block_types: "Tuple[str]" = ("DownEncoderBlock2D",), + up_block_types: "Tuple[str]" = ("UpDecoderBlock2D",), + block_out_channels: "Tuple[int]" = (64,), + layers_per_block: "int" = 1, + act_fn: "str" = "silu", + latent_channels: "int" = 3, + sample_size: "int" = 32, + num_vq_embeddings: "int" = 256, + norm_num_groups: "int" = 32, + vq_embed_dim: "Optional[int]" = None, + scaling_factor: "float" = 0.18215, + ): + super().__init__() + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=False, + ) + vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels + self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) + self.quantize = VectorQuantizer( + num_vq_embeddings, + vq_embed_dim, + beta=0.25, + remap=None, + sane_index_shape=False, + ) + self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + ) + + def encode( + self, x: "torch.FloatTensor", return_dict: "bool" = True + ) -> VQEncoderOutput: + h = self.encoder(x) + h = self.quant_conv(h) + if not return_dict: + return (h,) + return VQEncoderOutput(latents=h) + + def decode( + self, + h: "torch.FloatTensor", + force_not_quantize: "bool" = False, + return_dict: "bool" = True, + ) -> Union[DecoderOutput, torch.FloatTensor]: + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, sample: "torch.FloatTensor", return_dict: "bool" = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Args: + sample (`torch.FloatTensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
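+
+        Editorial note (not part of the upstream docstring): `forward` always routes the encoded
+        latents through the quantizer via `decode`. To decode continuous latents without the
+        codebook lookup, call `decode(h, force_not_quantize=True)` directly.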
+ """ + x = sample + h = self.encode(x).latents + dec = self.decode(h).sample + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + +class LDMBertAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: "int", + num_heads: "int", + head_dim: "int", + dropout: "float" = 0.0, + is_decoder: "bool" = False, + bias: "bool" = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.out_proj = nn.Linear(self.inner_dim, embed_dim) + + def _shape(self, tensor: "torch.Tensor", seq_len: "int", bsz: "int"): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: "torch.Tensor", + key_value_states: "Optional[torch.Tensor]" = None, + past_key_value: "Optional[Tuple[torch.Tensor]]" = None, + attention_mask: "Optional[torch.Tensor]" = None, + layer_head_mask: "Optional[torch.Tensor]" = None, + output_attentions: "bool" = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) * self.scaling + if is_cross_attention and past_key_value is not None: + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + if self.is_decoder: + past_key_value = key_states, value_states + proj_shape = bsz * self.num_heads, -1, self.head_dim + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {bsz * self.num_heads, tgt_len, src_len}, but is {attn_weights.size()}" + ) + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {bsz, 1, tgt_len, src_len}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise 
ValueError( + f"Head mask for a single layer should be of size {self.num_heads,}, but is {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if output_attentions: + attn_weights_reshaped = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights_reshaped.view( + bsz * self.num_heads, tgt_len, src_len + ) + else: + attn_weights_reshaped = None + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.bmm(attn_probs, value_states) + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {bsz, self.num_heads, tgt_len, self.head_dim}, but is {attn_output.size()}" + ) + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) + attn_output = self.out_proj(attn_output) + return attn_output, attn_weights_reshaped, past_key_value + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +# ACT2CLS = { +# "gelu": GELUActivation, +# "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), +# "gelu_fast": FastGELUActivation, +# "gelu_new": NewGELUActivation, +# "gelu_python": (GELUActivation, {"use_gelu_python": True}), +# "gelu_pytorch_tanh": PytorchGELUTanh, +# "linear": LinearActivation, +# "mish": MishActivation, +# "quick_gelu": QuickGELUActivation, +# "relu": nn.ReLU, +# "relu6": nn.ReLU6, +# "sigmoid": nn.Sigmoid, +# "silu": SiLUActivation, +# "swish": SiLUActivation, +# "tanh": nn.Tanh, +# } +# ACT2FN = ClassInstantier(ACT2CLS) + + +class LDMBertEncoderLayer(nn.Module): + def __init__(self, config: "LDMBertConfig"): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = LDMBertAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + head_dim=config.head_dim, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: "torch.FloatTensor", + attention_mask: "torch.FloatTensor", + layer_head_mask: "torch.FloatTensor", + output_attentions: "Optional[bool]" = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout( + hidden_states, p=self.activation_dropout, training=self.training + ) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + return outputs + + +class PaintByExampleMapper(nn.Module): + def __init__(self, config): + super().__init__() + num_layers = (config.num_hidden_layers + 1) // 5 + hid_size = config.hidden_size + num_heads = 1 + self.blocks = nn.ModuleList( + [ + BasicTransformerBlock( + hid_size, + num_heads, + hid_size, + activation_fn="gelu", + attention_bias=True, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states): + for block in self.blocks: + hidden_states = block(hidden_states) + return hidden_states + + +class GaussianSmoothing(torch.nn.Module): + """ + Arguments: + Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input + using a depthwise convolution. + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the + gaussian kernel. dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + + def __init__( + self, + channels: "int" = 1, + kernel_size: "int" = 3, + sigma: "float" = 0.5, + dim: "int" = 2, + ): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * dim + if isinstance(sigma, float): + sigma = [sigma] * dim + kernel = 1 + meshgrids = torch.meshgrid( + [torch.arange(size, dtype=torch.float32) for size in kernel_size] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= ( + 1 + / (std * math.sqrt(2 * math.pi)) + * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) + ) + kernel = kernel / torch.sum(kernel) + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *([1] * (kernel.dim() - 1))) + self.register_buffer("weight", kernel) + self.groups = channels + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError( + "Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim) + ) + + def forward(self, input): + """ + Arguments: + Apply gaussian filter to input. + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. 
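+
+        Usage sketch (illustrative, not part of the upstream docstring): the depthwise
+        convolution is applied without padding, so pad the input yourself if the output should
+        keep the input size; the reflect padding below is an arbitrary choice.
+
+            >>> import torch
+            >>> import torch.nn.functional as F
+
+            >>> smoother = GaussianSmoothing(channels=3, kernel_size=3, sigma=0.5, dim=2)
+            >>> img = torch.randn(1, 3, 64, 64)
+            >>> out = smoother(F.pad(img, (1, 1, 1, 1), mode="reflect"))
+            >>> out.shape
+            torch.Size([1, 3, 64, 64])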
+ """ + return self.conv(input, weight=self.weight, groups=self.groups) + + +class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): + """ + This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. + + It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image + embeddings. + """ + + @register_to_config + def __init__(self, embedding_dim: "int" = 768): + super().__init__() + self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) + self.std = nn.Parameter(torch.ones(1, embedding_dim)) + + def scale(self, embeds): + embeds = (embeds - self.mean) * 1.0 / self.std + return embeds + + def unscale(self, embeds): + embeds = embeds * self.std + self.mean + return embeds + + +class UnCLIPTextProjModel(ModelMixin, ConfigMixin): + """ + Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the + decoder. + + For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1 + """ + + @register_to_config + def __init__( + self, + *, + clip_extra_context_tokens: int = 4, + clip_embeddings_dim: int = 768, + time_embed_dim: int, + cross_attention_dim, + ): + super().__init__() + self.learned_classifier_free_guidance_embeddings = nn.Parameter( + torch.zeros(clip_embeddings_dim) + ) + self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim) + self.clip_image_embeddings_project_to_time_embeddings = nn.Linear( + clip_embeddings_dim, time_embed_dim + ) + self.clip_extra_context_tokens = clip_extra_context_tokens + self.clip_extra_context_tokens_proj = nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim + ) + self.encoder_hidden_states_proj = nn.Linear( + clip_embeddings_dim, cross_attention_dim + ) + self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim) + + def forward( + self, + *, + image_embeddings, + prompt_embeds, + text_encoder_hidden_states, + do_classifier_free_guidance, + ): + if do_classifier_free_guidance: + image_embeddings_batch_size = image_embeddings.shape[0] + classifier_free_guidance_embeddings = ( + self.learned_classifier_free_guidance_embeddings.unsqueeze(0) + ) + classifier_free_guidance_embeddings = ( + classifier_free_guidance_embeddings.expand( + image_embeddings_batch_size, -1 + ) + ) + image_embeddings = torch.cat( + [classifier_free_guidance_embeddings, image_embeddings], dim=0 + ) + assert image_embeddings.shape[0] == prompt_embeds.shape[0] + batch_size = prompt_embeds.shape[0] + time_projected_prompt_embeds = self.embedding_proj(prompt_embeds) + time_projected_image_embeddings = ( + self.clip_image_embeddings_project_to_time_embeddings(image_embeddings) + ) + additive_clip_time_embeddings = ( + time_projected_image_embeddings + time_projected_prompt_embeds + ) + clip_extra_context_tokens = self.clip_extra_context_tokens_proj( + image_embeddings + ) + clip_extra_context_tokens = clip_extra_context_tokens.reshape( + batch_size, -1, self.clip_extra_context_tokens + ) + text_encoder_hidden_states = self.encoder_hidden_states_proj( + text_encoder_hidden_states + ) + text_encoder_hidden_states = self.text_encoder_hidden_states_norm( + text_encoder_hidden_states + ) + text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1) + text_encoder_hidden_states = torch.cat( + [clip_extra_context_tokens, text_encoder_hidden_states], dim=2 + ) + return text_encoder_hidden_states, additive_clip_time_embeddings + + +class 
UNetMidBlockFlatCrossAttn(nn.Module): + def __init__( + self, + in_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +class UNetMidBlockFlatSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: "int", + temb_channels: "int", + dropout: "float" = 0.0, + num_layers: "int" = 1, + resnet_eps: "float" = 1e-06, + resnet_time_scale_shift: "str" = "default", + resnet_act_fn: "str" = "swish", + resnet_groups: "int" = 32, + resnet_pre_norm: "bool" = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + ): + super().__init__() + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.num_heads = in_channels // self.attn_num_head_channels + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + 
non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + for _ in range(num_layers): + attentions.append( + CrossAttention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +class UNetFlatConditionModel(ModelMixin, ConfigMixin): + """ + UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a + timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): + The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`, will skip + the mid block layer if `None`. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. 
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, or `"projection"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, default to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + timestep_post_act (`str, *optional*, default to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, default to `None`): + The dimension of `cond_proj` layer in timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. 
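+
+    Editorial note (not part of the upstream docstring): this class mirrors
+    `UNet2DConditionModel` above almost line for line; the visible difference is that
+    `conv_in` is a `LinearMultiDim` and the down/mid/up blocks are the `*Flat` variants built
+    on `ResnetBlockFlat`, which appear to replace the 2D convolutions with linear layers. See
+    the example on `UNet2DConditionModel` for the calling convention, which is the same.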
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: "Optional[int]" = None, + in_channels: "int" = 4, + out_channels: "int" = 4, + center_input_sample: "bool" = False, + flip_sin_to_cos: "bool" = True, + freq_shift: "int" = 0, + down_block_types: "Tuple[str]" = ( + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "DownBlockFlat", + ), + mid_block_type: "Optional[str]" = "UNetMidBlockFlatCrossAttn", + up_block_types: "Tuple[str]" = ( + "UpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + ), + only_cross_attention: "Union[bool, Tuple[bool]]" = False, + block_out_channels: "Tuple[int]" = (320, 640, 1280, 1280), + layers_per_block: "int" = 2, + downsample_padding: "int" = 1, + mid_block_scale_factor: "float" = 1, + act_fn: "str" = "silu", + norm_num_groups: "Optional[int]" = 32, + norm_eps: "float" = 1e-05, + cross_attention_dim: "int" = 1280, + attention_head_dim: "Union[int, Tuple[int]]" = 8, + dual_cross_attention: "bool" = False, + use_linear_projection: "bool" = False, + class_embed_type: "Optional[str]" = None, + num_class_embeds: "Optional[int]" = None, + upcast_attention: "bool" = False, + resnet_time_scale_shift: "str" = "default", + time_embedding_type: "str" = "positional", + timestep_post_act: "Optional[str]" = None, + time_cond_proj_dim: "Optional[int]" = None, + conv_in_kernel: "int" = 3, + conv_out_kernel: "int" = 3, + projection_class_embeddings_input_dim: "Optional[int]" = None, + ): + super().__init__() + self.sample_size = sample_size + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + if not isinstance(only_cross_attention, bool) and len( + only_cross_attention + ) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = LinearMultiDim( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + if time_embedding_type == "fourier": + time_embed_dim = block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError( + f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." + ) + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift + ) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. 
Pleaes make sure to use one of `fourier` or `positional`." + ) + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) + else: + self.class_embedding = None + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + if mid_block_type == "UNetMidBlockFlatCrossAttn": + self.mid_block = UNetMidBlockFlatCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": + self.mid_block = UNetMidBlockFlatSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + self.num_upsamplers = 0 + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + 
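With `time_embedding_type="positional"`, the `Timesteps` projection above produces the standard DDPM-style sinusoidal features that `TimestepEmbedding` then maps through an MLP. A simplified standalone sketch of that sinusoidal step (ignoring `flip_sin_to_cos`, `downscale_freq_shift`, and `scale`, which the full implementation in this file supports):

```python
import math
import torch
import torch.nn.functional as F

def sinusoidal_timestep_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    """Map a 1-D tensor of timesteps to (len(timesteps), dim) sin/cos features."""
    half = dim // 2
    # Geometric frequency ladder from 1 down to roughly 1/max_period.
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:  # zero-pad odd embedding widths
        emb = F.pad(emb, (0, 1))
    return emb

print(sinusoidal_timestep_embedding(torch.tensor([10, 500]), 320).shape)  # torch.Size([2, 320])
```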
output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = LinearMultiDim( + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + padding=conv_out_padding, + ) + + @property + def attn_processors(self) -> Dict[str, AttnProcessor]: + """ + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + processors = {} + + def fn_recursive_add_processors( + name: "str", + module: "torch.nn.Module", + processors: "Dict[str, AttnProcessor]", + ): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + return processors + + def set_attn_processor( + self, processor: "Union[AttnProcessor, Dict[str, AttnProcessor]]" + ): + """ + Parameters: + `processor (`dict` of `AttnProcessor` or `AttnProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `CrossAttention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: + + """ + count = len(self.attn_processors.keys()) + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor( + name: "str", module: "torch.nn.Module", processor + ): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_attention_slice(self, slice_size): + """ + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: "torch.nn.Module"): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + num_slicable_layers = len(sliceable_head_dims) + if slice_size == "auto": + slice_size = [(dim // 2) for dim in sliceable_head_dims] + elif slice_size == "max": + slice_size = num_slicable_layers * [1] + slice_size = ( + num_slicable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
+ ) + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + def fn_recursive_set_attention_slice( + module: "torch.nn.Module", slice_size: "List[int]" + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance( + module, + (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat), + ): + module.gradient_checkpointing = value + + def forward( + self, + sample: "torch.FloatTensor", + timestep: "Union[torch.Tensor, float, int]", + encoder_hidden_states: "torch.Tensor", + class_labels: "Optional[torch.Tensor]" = None, + timestep_cond: "Optional[torch.Tensor]" = None, + attention_mask: "Optional[torch.Tensor]" = None, + cross_attention_kwargs: "Optional[Dict[str, Any]]" = None, + down_block_additional_residuals: "Optional[Tuple[torch.Tensor]]" = None, + mid_block_additional_residual: "Optional[torch.Tensor]" = None, + return_dict: "bool" = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + """ + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. 
+ """ + default_overall_up_factor = 2 ** self.num_upsamplers + forward_upsample_size = False + upsample_size = None + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + if attention_mask is not None: + attention_mask = (1 - attention_mask) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + timesteps = timestep + if not torch.is_tensor(timesteps): + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None] + timesteps = timesteps.expand(sample.shape[0]) + t_emb = self.time_proj(timesteps) + t_emb = t_emb + emb = self.time_embedding(t_emb, timestep_cond) + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + class_emb = self.class_embedding(class_labels) + emb = emb + class_emb + sample = self.conv_in(sample) + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + down_block_res_samples += res_samples + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples += (down_block_res_sample,) + down_block_res_samples = new_down_block_res_samples + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + if self.conv_norm_out: + sample = 
self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + if not return_dict: + return (sample,) + return UNet2DConditionOutput(sample=sample) + + +class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): + """ + Utility class for storing learned text embeddings for classifier free sampling + """ + + @register_to_config + def __init__( + self, + learnable: "bool", + hidden_size: "Optional[int]" = None, + length: "Optional[int]" = None, + ): + super().__init__() + self.learnable = learnable + if self.learnable: + assert ( + hidden_size is not None + ), "learnable=True requires `hidden_size` to be set" + assert length is not None, "learnable=True requires `length` to be set" + embeddings = torch.zeros(length, hidden_size) + else: + embeddings = None + self.embeddings = torch.nn.Parameter(embeddings) diff --git a/pi/models/unet/__init__.py b/pi/models/unet/__init__.py deleted file mode 100644 index 1d34370..0000000 --- a/pi/models/unet/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .unet_2d_condition import UNet2DConditionModel \ No newline at end of file diff --git a/pi/models/unet/attention.py b/pi/models/unet/attention.py deleted file mode 100644 index 1bd0215..0000000 --- a/pi/models/unet/attention.py +++ /dev/null @@ -1,519 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math -from typing import Callable, Optional - -from ... import nn -from ... import pi -from ...nn import functional as F - -from .cross_attention import CrossAttention -from .embeddings import CombinedTimestepLabelEmbeddings - - -xformers = None - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - - Parameters: - channels (`int`): The number of channels in the input and output. - num_head_channels (`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. - """ - - # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. 
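In other words, the block flattens the spatial grid into a sequence of H*W tokens and runs ordinary scaled dot-product attention over them. A single-head sketch of that data flow (the real class also supports multiple heads, an output rescale factor, and optional xformers kernels):

```python
import math
import torch
import torch.nn as nn

class TinySpatialSelfAttention(nn.Module):
    """Single-head sketch: flatten (H, W) into a token axis and attend over it."""
    def __init__(self, channels: int, num_groups: int = 32, eps: float = 1e-5):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, channels, eps=eps)
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)
        self.proj = nn.Linear(channels, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        residual = x
        tokens = self.norm(x).view(b, c, h * w).transpose(1, 2)   # (B, H*W, C)
        q, k, v = self.query(tokens), self.key(tokens), self.value(tokens)
        attn = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(c), dim=-1)
        out = self.proj(attn @ v)                                  # (B, H*W, C)
        return out.transpose(1, 2).reshape(b, c, h, w) + residual

print(TinySpatialSelfAttention(64)(torch.randn(1, 64, 8, 8)).shape)  # torch.Size([1, 64, 8, 8])
```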
Do not use it anymore - - def __init__( - self, - channels: int, - num_head_channels: Optional[int] = None, - norm_num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, - ): - super().__init__() - self.channels = channels - - self.num_heads = ( - channels // num_head_channels if num_head_channels is not None else 1 - ) - self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm( - num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True - ) - - # define q,k,v as linear layers - self.query = nn.Linear(channels, channels) - self.key = nn.Linear(channels, channels) - self.value = nn.Linear(channels, channels) - - self.rescale_output_factor = rescale_output_factor - self.proj_attn = nn.Linear(channels, channels, 1) - - self._use_memory_efficient_attention_xformers = False - self._attention_op = None - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size * head_size, seq_len, dim // head_size - ) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size // head_size, seq_len, dim * head_size - ) - return tensor - - def set_use_memory_efficient_attention_xformers( - self, - use_memory_efficient_attention_xformers: bool, - attention_op: Optional[Callable] = None, - ): - # if use_memory_efficient_attention_xformers: - # if not is_xformers_available(): - # raise ModuleNotFoundError( - # ( - # "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - # " xformers" - # ), - # name="xformers", - # ) - # elif not pi.cuda.is_available(): - # raise ValueError( - # "pi.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is" - # " only available for GPU " - # ) - # else: - # try: - # # Make sure we can run the memory efficient attention - # _ = xformers.ops.memory_efficient_attention( - # pi.randn((1, 2, 40), device="cuda"), - # pi.randn((1, 2, 40), device="cuda"), - # pi.randn((1, 2, 40), device="cuda"), - # ) - # except Exception as e: - # raise e - # self._use_memory_efficient_attention_xformers = ( - # use_memory_efficient_attention_xformers - # ) - # self._attention_op = attention_op - raise NotImplementedError - - def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose( - 1, 2 - ) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - scale = 1 / math.sqrt(self.channels / self.num_heads) - - query_proj = self.reshape_heads_to_batch_dim(query_proj) - key_proj = self.reshape_heads_to_batch_dim(key_proj) - value_proj = self.reshape_heads_to_batch_dim(value_proj) - - if self._use_memory_efficient_attention_xformers: - # Memory efficient attention - hidden_states = xformers.ops.memory_efficient_attention( - query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op - ) - hidden_states = hidden_states.to(query_proj.dtype) - else: - attention_scores = pi.baddbmm( - pi.empty( - query_proj.shape[0], - query_proj.shape[1], - key_proj.shape[1], - dtype=query_proj.dtype, - device=query_proj.device, - ), - query_proj, - key_proj.transpose(-1, -2), - beta=0, - alpha=scale, - ) - attention_probs = pi.softmax(attention_scores.float(), dim=-1).type( - attention_scores.dtype - ) - hidden_states = pi.bmm(attention_probs, value_proj) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch, channel, height, width - ) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - -class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
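The block described above applies its sub-layers in pre-norm order: normalize, run self-attention, add the residual; then (optionally) cross-attention against the encoder hidden states; then the feed-forward network. A compact sketch of that ordering, using `torch.nn.MultiheadAttention` as a stand-in for this file's `CrossAttention` (so an approximation of the structure, not the class itself):

```python
import torch
import torch.nn as nn

class TinyTransformerBlock(nn.Module):
    """Pre-norm block: normalize, run the sub-layer, add it back as a residual."""
    def __init__(self, dim: int, heads: int = 4):
        super().__init__()
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(dim) for _ in range(3))
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x, context):
        n = self.norm1(x)
        x = x + self.self_attn(n, n, n)[0]                 # 1. self-attention
        n = self.norm2(x)
        x = x + self.cross_attn(n, context, context)[0]    # 2. cross-attention over the conditioning
        return x + self.ff(self.norm3(x))                  # 3. feed-forward

block = TinyTransformerBlock(dim=32)
out = block(torch.randn(2, 16, 32), torch.randn(2, 4, 32))
print(out.shape)  # torch.Size([2, 16, 32])
```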
- """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - final_dropout: bool = False, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - - self.use_ada_layer_norm_zero = ( - num_embeds_ada_norm is not None - ) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = ( - num_embeds_ada_norm is not None - ) and norm_type == "ada_norm" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - # 1. Self-Attn - self.attn1 = CrossAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - ) - - # 2. Cross-Attn - if cross_attention_dim is not None: - self.attn2 = CrossAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none - else: - self.attn2 = None - - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - - if cross_attention_dim is not None: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) - else: - self.norm2 = None - - # 3. Feed-forward - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - attention_mask=None, - cross_attention_kwargs=None, - class_labels=None, - ): - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - - # 1. 
Self-Attention - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states - if self.only_cross_attention - else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) - if self.use_ada_layer_norm - else self.norm2(hidden_states) - ) - - # 2. Cross-Attention - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = ( - norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ) - - ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = ff_output + hidden_states - - return hidden_states - - -class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - """ - - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(nn.Linear(inner_dim, dim_out)) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states): - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states - - -class GELU(nn.Module): - r""" - GELU activation function with tanh approximation support with `approximate="tanh"`. 
- """ - - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out) - self.approximate = approximate - - def gelu(self, gate): - if gate.device.type != "mps": - return F.gelu(gate, approximate=self.approximate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=pi.float32), approximate=self.approximate).to( - dtype=gate.dtype - ) - - def forward(self, hidden_states): - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) - return hidden_states - - -class GEGLU(nn.Module): - r""" - A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - """ - - def __init__(self, dim_in: int, dim_out: int): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def gelu(self, gate): - if gate.device.type != "mps": - return F.gelu(gate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=pi.float32)).to(dtype=gate.dtype) - - def forward(self, hidden_states): - hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) - return hidden_states * self.gelu(gate) - - -class ApproximateGELU(nn.Module): - """ - The approximate form of Gaussian Error Linear Unit (GELU) - - For more details, see section 2: https://arxiv.org/abs/1606.08415 - """ - - def __init__(self, dim_in: int, dim_out: int): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out) - - def forward(self, x): - x = self.proj(x) - return x * pi.sigmoid(1.702 * x) - - -class AdaLayerNorm(nn.Module): - """ - Norm layer modified to incorporate timestep embeddings. - """ - - def __init__(self, embedding_dim, num_embeddings): - super().__init__() - self.emb = nn.Embedding(num_embeddings, embedding_dim) - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, embedding_dim * 2) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) - - def forward(self, x, timestep): - emb = self.linear(self.silu(self.emb(timestep))) - scale, shift = pi.chunk(emb, 2) - x = self.norm(x) * (1 + scale) + shift - return x - - -class AdaLayerNormZero(nn.Module): - """ - Norm layer adaptive layer norm zero (adaLN-Zero). - """ - - def __init__(self, embedding_dim, num_embeddings): - super().__init__() - - self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) - - def forward(self, x, timestep, class_labels, hidden_dtype=None): - emb = self.linear( - self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)) - ) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk( - 6, dim=1 - ) - x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] - return x, gate_msa, shift_mlp, scale_mlp, gate_mlp diff --git a/pi/models/unet/cross_attention.py b/pi/models/unet/cross_attention.py deleted file mode 100644 index dc15f86..0000000 --- a/pi/models/unet/cross_attention.py +++ /dev/null @@ -1,679 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable, Optional, Union - -from ... import nn -from ... import pi -from ...nn import functional as F - - -xformers = None - - -class CrossAttention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (`int`): The number of channels in the query. - cross_attention_dim (`int`, *optional*): - The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - bias (`bool`, *optional*, defaults to False): - Set to `True` for the query, key, and value linear layers to contain a bias parameter. - """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias=False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - added_kv_proj_dim: Optional[int] = None, - norm_num_groups: Optional[int] = None, - processor: Optional["AttnProcessor"] = None, - ): - super().__init__() - inner_dim = dim_head * heads - cross_attention_dim = ( - cross_attention_dim if cross_attention_dim is not None else query_dim - ) - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - - self.scale = dim_head ** -0.5 - - self.heads = heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self.sliceable_head_dim = heads - - self.added_kv_proj_dim = added_kv_proj_dim - - if norm_num_groups is not None: - self.group_norm = nn.GroupNorm( - num_channels=inner_dim, - num_groups=norm_num_groups, - eps=1e-5, - affine=True, - ) - else: - self.group_norm = None - - self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(inner_dim, query_dim)) - self.to_out.append(nn.Dropout(dropout)) - - # set attention processor - processor = processor if processor is not None else CrossAttnProcessor() - self.set_processor(processor) - - def set_use_memory_efficient_attention_xformers( - self, - use_memory_efficient_attention_xformers: bool, - attention_op: Optional[Callable] = None, - ): - if use_memory_efficient_attention_xformers: - # if self.added_kv_proj_dim is not None: - # # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP - # # which uses this type of cross attention ONLY because the attention mask of format - # # [0, ..., -10.000, ..., 0, ...,] is not supported - # raise NotImplementedError( - # "Memory efficient attention with `xformers` is 
currently not supported when" - # " `self.added_kv_proj_dim` is defined." - # ) - # elif not is_xformers_available(): - # raise ModuleNotFoundError( - # ( - # "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - # " xformers" - # ), - # name="xformers", - # ) - # elif not pi.cuda.is_available(): - # raise ValueError( - # "pi.cuda.is_available() should be True but is False. xformers' memory efficient attention is" - # " only available for GPU " - # ) - # else: - # try: - # # Make sure we can run the memory efficient attention - # _ = xformers.ops.memory_efficient_attention( - # pi.randn((1, 2, 40), device="cuda"), - # pi.randn((1, 2, 40), device="cuda"), - # pi.randn((1, 2, 40), device="cuda"), - # ) - # except Exception as e: - # raise e - # - # processor = XFormersCrossAttnProcessor(attention_op=attention_op) - raise NotImplementedError - else: - processor = CrossAttnProcessor() - - self.set_processor(processor) - - def set_attention_slice(self, slice_size): - if slice_size is not None and slice_size > self.sliceable_head_dim: - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}." - ) - - if slice_size is not None and self.added_kv_proj_dim is not None: - processor = SlicedAttnAddedKVProcessor(slice_size) - elif slice_size is not None: - processor = SlicedAttnProcessor(slice_size) - elif self.added_kv_proj_dim is not None: - processor = CrossAttnAddedKVProcessor() - else: - processor = CrossAttnProcessor() - - self.set_processor(processor) - - def set_processor(self, processor: "AttnProcessor"): - self.processor = processor - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - **cross_attention_kwargs, - ): - # The `CrossAttention` class can call different attention processors / attention functions - # here we simply pass along all tensors to the selected processor class - # For standard processors that are defined here, `**cross_attention_kwargs` is empty - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - def batch_to_head_dim(self, tensor): - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size // head_size, seq_len, dim * head_size - ) - return tensor - - def head_to_batch_dim(self, tensor): - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape( - batch_size * head_size, seq_len, dim // head_size - ) - return tensor - - def get_attention_scores(self, query, key, attention_mask=None): - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = pi.baddbmm( - pi.empty( - query.shape[0], - query.shape[1], - key.shape[1], - dtype=query.dtype, - device=query.device, - ), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - - if attention_mask is not None: - attention_scores = attention_scores + attention_mask - - if self.upcast_softmax: - attention_scores = attention_scores.float() - - attention_probs = attention_scores.softmax(dim=-1) - attention_probs = attention_probs.to(dtype) - - return attention_probs - - def prepare_attention_mask(self, attention_mask, target_length): - head_size = 
self.heads - if attention_mask is None: - return attention_mask - - if attention_mask.shape[-1] != target_length: - if attention_mask.device.type == "mps": - # HACK: MPS: Does not support padding by greater than dimension of input tensor. - # Instead, we can manually construct the padding tensor. - padding_shape = ( - attention_mask.shape[0], - attention_mask.shape[1], - target_length, - ) - padding = pi.zeros(padding_shape, device=attention_mask.device) - attention_mask = pi.concat([attention_mask, padding], dim=2) - else: - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) - return attention_mask - - -class CrossAttnProcessor: - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = pi.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4): - super().__init__() - - if rank > min(in_features, out_features): - raise ValueError( - f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}" - ) - - self.down = nn.Linear(in_features, rank, bias=False) - self.up = nn.Linear(rank, out_features, bias=False) - self.scale = 1.0 - - nn.init.normal_(self.down.weight, std=1 / rank) - nn.init.zeros_(self.up.weight) - - def forward(self, hidden_states): - orig_dtype = hidden_states.dtype - dtype = self.down.weight.dtype - - down_hidden_states = self.down(hidden_states.to(dtype)) - up_hidden_states = self.up(down_hidden_states) - - return up_hidden_states.to(orig_dtype) - - -class LoRACrossAttnProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim=None, rank=4): - super().__init__() - - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size) - self.to_k_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size - ) - self.to_v_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size - ) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size) - - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - scale=1.0, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) - query = attn.head_to_batch_dim(query) - - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - - key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora( - encoder_hidden_states - ) - value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora( - encoder_hidden_states 
- ) - - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = pi.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora( - hidden_states - ) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class CrossAttnAddedKVProcessor: - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - residual = hidden_states - hidden_states = hidden_states.view( - hidden_states.shape[0], hidden_states.shape[1], -1 - ).transpose(1, 2) - batch_size, sequence_length, _ = hidden_states.shape - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.head_to_batch_dim( - encoder_hidden_states_key_proj - ) - encoder_hidden_states_value_proj = attn.head_to_batch_dim( - encoder_hidden_states_value_proj - ) - - key = pi.concat([encoder_hidden_states_key_proj, key], dim=1) - value = pi.concat([encoder_hidden_states_value_proj, value], dim=1) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = pi.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual - - return hidden_states - - -class XFormersCrossAttnProcessor: - def __init__(self, attention_op: Optional[Callable] = None): - self.attention_op = attention_op - - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - hidden_states = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias=attention_mask, op=self.attention_op - ) - hidden_states = hidden_states.to(query.dtype) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -class LoRAXFormersCrossAttnProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim, rank=4): - super().__init__() - - self.to_q_lora = 
LoRALinearLayer(hidden_size, hidden_size) - self.to_k_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size - ) - self.to_v_lora = LoRALinearLayer( - cross_attention_dim or hidden_size, hidden_size - ) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size) - - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - scale=1.0, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) - query = attn.head_to_batch_dim(query).contiguous() - - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - - key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora( - encoder_hidden_states - ) - value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora( - encoder_hidden_states - ) - - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - hidden_states = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias=attention_mask - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora( - hidden_states - ) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class SlicedAttnProcessor: - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - batch_size_attention = query.shape[0] - hidden_states = pi.zeros( - (batch_size_attention, sequence_length, dim // attn.heads), - device=query.device, - dtype=query.dtype, - ) - - for i in range(hidden_states.shape[0] // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = ( - attention_mask[start_idx:end_idx] - if attention_mask is not None - else None - ) - - attn_slice = attn.get_attention_scores( - query_slice, key_slice, attn_mask_slice - ) - - attn_slice = pi.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class SlicedAttnAddedKVProcessor: - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__( - self, - attn: "CrossAttention", - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - residual = hidden_states - hidden_states = hidden_states.view( - hidden_states.shape[0], hidden_states.shape[1], -1 - ).transpose(1, 2) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - - batch_size, sequence_length, _ = hidden_states.shape 
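The two sliced processors here trade a little speed for memory by never materializing the full attention-score matrix for every head at once. A standalone sketch of that slicing idea (plain `torch.bmm` with an explicit scale instead of the `pi.baddbmm` call used in this file, and without the added key/value projections):

```python
import torch

def sliced_attention(q, k, v, slice_size: int):
    """softmax(q @ k^T * scale) @ v, computed chunk-by-chunk along the
    (batch * heads) axis so only one slice of scores lives in memory at a time."""
    scale = q.shape[-1] ** -0.5
    out = torch.empty_like(q)
    for start in range(0, q.shape[0], slice_size):
        end = start + slice_size
        scores = torch.bmm(q[start:end], k[start:end].transpose(-1, -2)) * scale
        out[start:end] = torch.bmm(scores.softmax(dim=-1), v[start:end])
    return out

q = k = v = torch.randn(8, 64, 40)  # (batch * heads, seq_len, head_dim)
full = torch.bmm(torch.softmax(torch.bmm(q, k.transpose(-1, -2)) * q.shape[-1] ** -0.5, dim=-1), v)
print(torch.allclose(sliced_attention(q, k, v, slice_size=2), full, atol=1e-5))  # True
```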
- - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - encoder_hidden_states_key_proj = attn.head_to_batch_dim( - encoder_hidden_states_key_proj - ) - encoder_hidden_states_value_proj = attn.head_to_batch_dim( - encoder_hidden_states_value_proj - ) - - key = pi.concat([encoder_hidden_states_key_proj, key], dim=1) - value = pi.concat([encoder_hidden_states_value_proj, value], dim=1) - - batch_size_attention = query.shape[0] - hidden_states = pi.zeros( - (batch_size_attention, sequence_length, dim // attn.heads), - device=query.device, - dtype=query.dtype, - ) - - for i in range(hidden_states.shape[0] // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = ( - attention_mask[start_idx:end_idx] - if attention_mask is not None - else None - ) - - attn_slice = attn.get_attention_scores( - query_slice, key_slice, attn_mask_slice - ) - - attn_slice = pi.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual - - return hidden_states - - -AttnProcessor = Union[ - CrossAttnProcessor, - XFormersCrossAttnProcessor, - SlicedAttnProcessor, - CrossAttnAddedKVProcessor, - SlicedAttnAddedKVProcessor, -] diff --git a/pi/models/unet/embeddings.py b/pi/models/unet/embeddings.py deleted file mode 100644 index de0cbe8..0000000 --- a/pi/models/unet/embeddings.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math - -import numpy as np -from ... import nn -from ... import pi -from ...nn import functional as F - - -def get_timestep_embedding( - timesteps: pi.Tensor, - embedding_dim: int, - flip_sin_to_cos: bool = False, - downscale_freq_shift: float = 1, - scale: float = 1, - max_period: int = 10000, -): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. 
:return: an [N x dim] Tensor of positional embeddings. - """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * pi.arange( - start=0, end=half_dim, dtype=pi.float32, device=timesteps.device - ) - exponent = exponent / (half_dim - downscale_freq_shift) - - emb = pi.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - - # scale embeddings - emb = scale * emb - - # concat sine and cosine embeddings - emb = pi.cat([pi.sin(emb), pi.cos(emb)], dim=-1) - - # flip sine and cosine embeddings - if flip_sin_to_cos: - emb = pi.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - emb = pi.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): - """ - grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=np.float32) - grid_w = np.arange(grid_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token and extra_tokens > 0: - pos_embed = np.concatenate( - [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 - ) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000 ** omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - height=224, - width=224, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - ): - super().__init__() - - num_patches = (height // patch_size) * (width // patch_size) - self.flatten = flatten - self.layer_norm = layer_norm - - self.proj = nn.Conv2d( - in_channels, - embed_dim, - kernel_size=(patch_size, patch_size), - stride=patch_size, - bias=bias, - ) - if layer_norm: - self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) - else: - self.norm = None - - pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches ** 0.5)) - self.register_buffer( - "pos_embed", pi.from_numpy(pos_embed).float().unsqueeze(0), persistent=False - ) - - def forward(self, latent): - latent = self.proj(latent) - if self.flatten: - latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC - if self.layer_norm: - latent = self.norm(latent) - 
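`get_timestep_embedding` above is the standard DDPM sinusoidal embedding: a geometric ladder of frequencies, each contributing a sin/cos pair. A compact plain-PyTorch sketch using the defaults from the signature above (`downscale_freq_shift=1`, `scale=1`, no sin/cos flip):

```python
import math
import torch
import torch.nn.functional as F


def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
    """timesteps: 1-D tensor of (possibly fractional) step indices -> (N, dim)."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / (half - 1)
    )
    args = timesteps[:, None].float() * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:
        emb = F.pad(emb, (0, 1))  # zero-pad odd embedding dimensions
    return emb


print(sinusoidal_timestep_embedding(torch.tensor([0, 10, 999]), 8).shape)  # torch.Size([3, 8])
```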
return latent + self.pos_embed - - -class TimestepEmbedding(nn.Module): - def __init__( - self, - in_channels: int, - time_embed_dim: int, - act_fn: str = "silu", - out_dim: int = None, - ): - super().__init__() - - self.linear_1 = nn.Linear(in_channels, time_embed_dim) - self.act = None - if act_fn == "silu": - self.act = nn.SiLU() - elif act_fn == "mish": - self.act = nn.Mish() - - if out_dim is not None: - time_embed_dim_out = out_dim - else: - time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) - - def forward(self, sample): - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - return sample - - -class Timesteps(nn.Module): - def __init__( - self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float - ): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - ) - return t_emb - - -class GaussianFourierProjection(nn.Module): - """Gaussian Fourier embeddings for noise levels.""" - - def __init__( - self, - embedding_size: int = 256, - scale: float = 1.0, - set_W_to_weight=True, - log=True, - flip_sin_to_cos=False, - ): - super().__init__() - self.weight = nn.Parameter( - pi.randn(embedding_size) * scale, requires_grad=False - ) - self.log = log - self.flip_sin_to_cos = flip_sin_to_cos - - if set_W_to_weight: - # to delete later - self.W = nn.Parameter(pi.randn(embedding_size) * scale, requires_grad=False) - - self.weight = self.W - - def forward(self, x): - if self.log: - x = pi.log(x) - - x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi - - if self.flip_sin_to_cos: - out = pi.cat([pi.cos(x_proj), pi.sin(x_proj)], dim=-1) - else: - out = pi.cat([pi.sin(x_proj), pi.cos(x_proj)], dim=-1) - return out - - -class ImagePositionalEmbeddings(nn.Module): - """ - Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the - height and width of the latent space. - - For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 - - For VQ-diffusion: - - Output vector embeddings are used as input for the transformer. - - Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. - - Args: - num_embed (`int`): - Number of embeddings for the latent pixels embeddings. - height (`int`): - Height of the latent image i.e. the number of height embeddings. - width (`int`): - Width of the latent image i.e. the number of width embeddings. - embed_dim (`int`): - Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. 
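`PatchEmbed` above tokenises an image with a single Conv2d whose kernel size and stride both equal the patch size, flattens the spatial grid into a sequence, and adds a fixed 2-D sin/cos position table. A shape-level sketch of the patchify step (the 224x224 input and 768-dim embedding are illustrative):

```python
import torch
import torch.nn as nn

patch_size, in_channels, embed_dim = 16, 3, 768
proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

image = torch.randn(2, in_channels, 224, 224)   # (B, C, H, W)
tokens = proj(image)                            # (B, D, H/16, W/16) = (2, 768, 14, 14)
tokens = tokens.flatten(2).transpose(1, 2)      # BCHW -> BNC, N = 14 * 14 = 196
print(tokens.shape)                             # torch.Size([2, 196, 768])
```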
- """ - - def __init__( - self, - num_embed: int, - height: int, - width: int, - embed_dim: int, - ): - super().__init__() - - self.height = height - self.width = width - self.num_embed = num_embed - self.embed_dim = embed_dim - - self.emb = nn.Embedding(self.num_embed, embed_dim) - self.height_emb = nn.Embedding(self.height, embed_dim) - self.width_emb = nn.Embedding(self.width, embed_dim) - - def forward(self, index): - emb = self.emb(index) - - height_emb = self.height_emb( - pi.arange(self.height, device=index.device).view(1, self.height) - ) - - # 1 x H x D -> 1 x H x 1 x D - height_emb = height_emb.unsqueeze(2) - - width_emb = self.width_emb( - pi.arange(self.width, device=index.device).view(1, self.width) - ) - - # 1 x W x D -> 1 x 1 x W x D - width_emb = width_emb.unsqueeze(1) - - pos_emb = height_emb + width_emb - - # 1 x H x W x D -> 1 x L xD - pos_emb = pos_emb.view(1, self.height * self.width, -1) - - emb = emb + pos_emb[:, : emb.shape[1], :] - - return emb - - -class LabelEmbedding(nn.Module): - """ - Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. - - Args: - num_classes (`int`): The number of classes. - hidden_size (`int`): The size of the vector embeddings. - dropout_prob (`float`): The probability of dropping a label. - """ - - def __init__(self, num_classes, hidden_size, dropout_prob): - super().__init__() - use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding( - num_classes + use_cfg_embedding, hidden_size - ) - self.num_classes = num_classes - self.dropout_prob = dropout_prob - - def token_drop(self, labels, force_drop_ids=None): - """ - Drops labels to enable classifier-free guidance. - """ - if force_drop_ids is None: - drop_ids = ( - pi.rand(labels.shape[0], device=labels.device) < self.dropout_prob - ) - else: - drop_ids = pi.tensor(force_drop_ids == 1) - labels = pi.where(drop_ids, self.num_classes, labels) - return labels - - def forward(self, labels, force_drop_ids=None): - use_dropout = self.dropout_prob > 0 - if (self.training and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) - embeddings = self.embedding_table(labels) - return embeddings - - -class CombinedTimestepLabelEmbeddings(nn.Module): - def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): - super().__init__() - - self.time_proj = Timesteps( - num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1 - ) - self.timestep_embedder = TimestepEmbedding( - in_channels=256, time_embed_dim=embedding_dim - ) - self.class_embedder = LabelEmbedding( - num_classes, embedding_dim, class_dropout_prob - ) - - def forward(self, timestep, class_labels, hidden_dtype=None): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder( - timesteps_proj.to(dtype=hidden_dtype) - ) # (N, D) - - class_labels = self.class_embedder(class_labels) # (N, D) - - conditioning = timesteps_emb + class_labels # (N, D) - - return conditioning diff --git a/pi/models/unet/resnet.py b/pi/models/unet/resnet.py deleted file mode 100644 index ec95523..0000000 --- a/pi/models/unet/resnet.py +++ /dev/null @@ -1,752 +0,0 @@ -from functools import partial - -from ... import nn -from ... import pi -from ...nn import functional as F - - -class Upsample1D(nn.Module): - """ - An upsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. 
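`LabelEmbedding.token_drop` above implements the label dropout used for classifier-free guidance: with probability `dropout_prob`, a class label is replaced by a reserved null id, so the embedding table carries one extra row. A small sketch of the same mechanism (sizes are illustrative):

```python
import torch
import torch.nn as nn

num_classes, hidden_size, dropout_prob = 10, 64, 0.1
labels = torch.randint(0, num_classes, (16,))

# Replace each label with the reserved null id (== num_classes) with
# probability dropout_prob; the model then also learns an unconditional path.
drop_ids = torch.rand(labels.shape[0]) < dropout_prob
dropped = torch.where(drop_ids, torch.full_like(labels, num_classes), labels)

# One extra embedding row for the null class.
table = nn.Embedding(num_classes + 1, hidden_size)
print(table(dropped).shape)  # torch.Size([16, 64])
```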
- use_conv_transpose: - out_channels: - """ - - def __init__( - self, - channels, - use_conv=False, - use_conv_transpose=False, - out_channels=None, - name="conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - self.conv = None - if use_conv_transpose: - self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.use_conv_transpose: - return self.conv(x) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - - if self.use_conv: - x = self.conv(x) - - return x - - -class Downsample1D(nn.Module): - """ - A downsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - out_channels: - padding: - """ - - def __init__( - self, channels, use_conv=False, out_channels=None, padding=1, name="conv" - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - self.conv = nn.Conv1d( - self.channels, self.out_channels, 3, stride=stride, padding=padding - ) - else: - assert self.channels == self.out_channels - self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.conv(x) - - -class Upsample2D(nn.Module): - """ - An upsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - use_conv_transpose: - out_channels: - """ - - def __init__( - self, - channels, - use_conv=False, - use_conv_transpose=False, - out_channels=None, - name="conv", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - conv = None - if use_conv_transpose: - conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - def forward(self, hidden_states, output_size=None): - assert hidden_states.shape[1] == self.channels - - if self.use_conv_transpose: - return self.conv(hidden_states) - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch - # https://github.com/pytorch/pytorch/issues/86679 - dtype = hidden_states.dtype - if dtype == pi.bfloat16: - hidden_states = hidden_states.to(pi.float32) - - # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output - # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = F.interpolate( - hidden_states, scale_factor=2.0, mode="nearest" - ) - else: - hidden_states = F.interpolate( - hidden_states, size=output_size, mode="nearest" - ) - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == pi.bfloat16: - hidden_states = hidden_states.to(dtype) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if self.use_conv: - if self.name == "conv": - hidden_states = self.conv(hidden_states) - else: - hidden_states = self.Conv2d_0(hidden_states) - - return hidden_states - - -class Downsample2D(nn.Module): - """ - A downsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - out_channels: - padding: - """ - - def __init__( - self, channels, use_conv=False, out_channels=None, padding=1, name="conv" - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - conv = nn.Conv2d( - self.channels, self.out_channels, 3, stride=stride, padding=padding - ) - else: - assert self.channels == self.out_channels - conv = nn.AvgPool2d(kernel_size=stride, stride=stride) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.Conv2d_0 = conv - self.conv = conv - elif name == "Conv2d_0": - self.conv = conv - else: - self.conv = conv - - def forward(self, hidden_states): - assert hidden_states.shape[1] == self.channels - if self.use_conv and self.padding == 0: - pad = (0, 1, 0, 1) - hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) - - assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) - - return hidden_states - - -class FirUpsample2D(nn.Module): - def __init__( - self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) - ): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d( - channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - self.use_conv = use_conv - self.fir_kernel = fir_kernel - self.out_channels = out_channels - - def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): - """Fused `upsample_2d()` followed by `Conv2d()`. - - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight: Weight tensor of the shape `[filterH, filterW, inChannels, - outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). 
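`Upsample2D` and `Downsample2D` above are thin wrappers: x2 nearest-neighbour interpolation (or a ConvTranspose2d) followed by an optional 3x3 convolution on the way up, and a stride-2 3x3 convolution (or average pooling) on the way down. A shape-level sketch with illustrative channel counts:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 64, 32, 32)

# Upsample: nearest x2, then a 3x3 conv that keeps (or changes) the channel count.
up = nn.Conv2d(64, 64, kernel_size=3, padding=1)
y = up(F.interpolate(x, scale_factor=2.0, mode="nearest"))
print(y.shape)  # torch.Size([1, 64, 64, 64])

# Downsample: a stride-2 3x3 conv halves the spatial size.
down = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
z = down(x)
print(z.shape)  # torch.Size([1, 64, 16, 16])
```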
- - Returns: - output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same - datatype as `hidden_states`. - """ - - assert isinstance(factor, int) and factor >= 1 - - # Setup filter kernel. - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = pi.tensor(kernel, dtype=pi.float32) - if kernel.ndim == 1: - kernel = pi.outer(kernel, kernel) - kernel /= pi.sum(kernel) - - kernel = kernel * (gain * (factor ** 2)) - - if self.use_conv: - convH = weight.shape[2] - convW = weight.shape[3] - inC = weight.shape[1] - - pad_value = (kernel.shape[0] - factor) - (convW - 1) - - stride = (factor, factor) - # Determine data dimensions. - output_shape = ( - (hidden_states.shape[2] - 1) * factor + convH, - (hidden_states.shape[3] - 1) * factor + convW, - ) - output_padding = ( - output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, - output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, - ) - assert output_padding[0] >= 0 and output_padding[1] >= 0 - num_groups = hidden_states.shape[1] // inC - - # Transpose weights. - weight = pi.reshape(weight, (num_groups, -1, inC, convH, convW)) - weight = pi.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) - weight = pi.reshape(weight, (num_groups * inC, -1, convH, convW)) - - inverse_conv = F.conv_transpose2d( - hidden_states, - weight, - stride=stride, - output_padding=output_padding, - padding=0, - ) - - output = upfirdn2d_native( - inverse_conv, - pi.tensor(kernel, device=inverse_conv.device), - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), - ) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - pi.tensor(kernel, device=hidden_states.device), - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - - return output - - def forward(self, hidden_states): - if self.use_conv: - height = self._upsample_2d( - hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel - ) - height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) - - return height - - -class FirDownsample2D(nn.Module): - def __init__( - self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) - ): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d( - channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - self.fir_kernel = fir_kernel - self.use_conv = use_conv - self.out_channels = out_channels - - def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): - """Fused `Conv2d()` followed by `downsample_2d()`. - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight: - Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be - performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * - factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). 
- - Returns: - output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and - same datatype as `x`. - """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = pi.tensor(kernel, dtype=pi.float32) - if kernel.ndim == 1: - kernel = pi.outer(kernel, kernel) - kernel /= pi.sum(kernel) - - kernel = kernel * gain - - if self.use_conv: - _, _, convH, convW = weight.shape - pad_value = (kernel.shape[0] - factor) + (convW - 1) - stride_value = [factor, factor] - upfirdn_input = upfirdn2d_native( - hidden_states, - pi.tensor(kernel, device=hidden_states.device), - pad=((pad_value + 1) // 2, pad_value // 2), - ) - output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - pi.tensor(kernel, device=hidden_states.device), - down=factor, - pad=((pad_value + 1) // 2, pad_value // 2), - ) - - return output - - def forward(self, hidden_states): - if self.use_conv: - downsample_input = self._downsample_2d( - hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel - ) - hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - hidden_states = self._downsample_2d( - hidden_states, kernel=self.fir_kernel, factor=2 - ) - - return hidden_states - - -class ResnetBlock2D(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-6, - non_linearity="swish", - time_embedding_norm="default", - kernel=None, - output_scale_factor=1.0, - use_in_shortcut=None, - up=False, - down=False, - ): - super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm - self.up = up - self.down = down - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - self.norm1 = pi.nn.GroupNorm( - num_groups=groups, num_channels=in_channels, eps=eps, affine=True - ) - - self.conv1 = pi.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - if temb_channels is not None: - if self.time_embedding_norm == "default": - time_emb_proj_out_channels = out_channels - elif self.time_embedding_norm == "scale_shift": - time_emb_proj_out_channels = out_channels * 2 - else: - raise ValueError( - f"unknown time_embedding_norm : {self.time_embedding_norm} " - ) - - self.time_emb_proj = pi.nn.Linear(temb_channels, time_emb_proj_out_channels) - else: - self.time_emb_proj = None - - self.norm2 = pi.nn.GroupNorm( - num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True - ) - self.dropout = pi.nn.Dropout(dropout) - self.conv2 = pi.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - - self.upsample = self.downsample = None - if self.up: - if kernel == "fir": - fir_kernel = (1, 3, 3, 1) - self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.upsample = partial(F.interpolate, scale_factor=2.0, 
mode="nearest") - else: - self.upsample = Upsample2D(in_channels, use_conv=False) - elif self.down: - if kernel == "fir": - fir_kernel = (1, 3, 3, 1) - self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) - else: - self.downsample = Downsample2D( - in_channels, use_conv=False, padding=1, name="op" - ) - - self.use_in_shortcut = ( - self.in_channels != self.out_channels - if use_in_shortcut is None - else use_in_shortcut - ) - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = pi.nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, input_tensor, temb): - hidden_states = input_tensor - - hidden_states = self.norm1(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - input_tensor = input_tensor.contiguous() - hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor) - hidden_states = self.upsample(hidden_states) - elif self.downsample is not None: - input_tensor = self.downsample(input_tensor) - hidden_states = self.downsample(hidden_states) - - hidden_states = self.conv1(hidden_states) - - if temb is not None: - temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] - - if temb is not None and self.time_embedding_norm == "default": - hidden_states = hidden_states + temb - - hidden_states = self.norm2(hidden_states) - - if temb is not None and self.time_embedding_norm == "scale_shift": - scale, shift = pi.chunk(temb, 2, dim=1) - hidden_states = hidden_states * (1 + scale) + shift - - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - - return output_tensor - - -class Mish(pi.nn.Module): - def forward(self, hidden_states): - return hidden_states * pi.tanh(pi.nn.functional.softplus(hidden_states)) - - -# unet_rl.py -def rearrange_dims(tensor): - if len(tensor.shape) == 2: - return tensor[:, :, None] - if len(tensor.shape) == 3: - return tensor[:, :, None, :] - elif len(tensor.shape) == 4: - return tensor[:, :, 0, :] - else: - raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") - - -class Conv1dBlock(nn.Module): - """ - Conv1d --> GroupNorm --> Mish - """ - - def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): - super().__init__() - - self.conv1d = nn.Conv1d( - inp_channels, out_channels, kernel_size, padding=kernel_size // 2 - ) - self.group_norm = nn.GroupNorm(n_groups, out_channels) - self.mish = nn.Mish() - - def forward(self, x): - x = self.conv1d(x) - x = rearrange_dims(x) - x = self.group_norm(x) - x = rearrange_dims(x) - x = self.mish(x) - return x - - -# unet_rl.py -class ResidualTemporalBlock1D(nn.Module): - def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): - super().__init__() - self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) - self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) - - self.time_emb_act = nn.Mish() - self.time_emb = nn.Linear(embed_dim, out_channels) - - self.residual_conv = ( - nn.Conv1d(inp_channels, 
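`ResnetBlock2D` above injects the time embedding in one of two ways: `"default"` adds the projected embedding to the feature map, while `"scale_shift"` splits the projection in two and applies it FiLM-style around the second group norm. A reduced sketch of just the conditioning step, ignoring the surrounding convolutions (layer sizes are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

channels, temb_dim = 64, 512
norm = nn.GroupNorm(32, channels)
time_proj_default = nn.Linear(temb_dim, channels)           # "default": additive
time_proj_scale_shift = nn.Linear(temb_dim, channels * 2)   # "scale_shift": FiLM

h = torch.randn(2, channels, 16, 16)
temb = torch.randn(2, temb_dim)

# default: h + proj(silu(temb)), broadcast over the spatial dimensions
h_default = h + time_proj_default(F.silu(temb))[:, :, None, None]

# scale_shift: norm(h) * (1 + scale) + shift
scale, shift = time_proj_scale_shift(F.silu(temb))[:, :, None, None].chunk(2, dim=1)
h_film = norm(h) * (1 + scale) + shift
print(h_default.shape, h_film.shape)
```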
out_channels, 1) - if inp_channels != out_channels - else nn.Identity() - ) - - def forward(self, x, t): - """ - Args: - x : [ batch_size x inp_channels x horizon ] - t : [ batch_size x embed_dim ] - - returns: - out : [ batch_size x out_channels x horizon ] - """ - t = self.time_emb_act(t) - t = self.time_emb(t) - out = self.conv_in(x) + rearrange_dims(t) - out = self.conv_out(out) - return out + self.residual_conv(x) - - -def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): - r"""Upsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given - filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified - `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is - a: multiple of the upsampling factor. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). - - Returns: - output: Tensor of the shape `[N, C, H * factor, W * factor]` - """ - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = pi.tensor(kernel, dtype=pi.float32) - if kernel.ndim == 1: - kernel = pi.outer(kernel, kernel) - kernel /= pi.sum(kernel) - - kernel = kernel * (gain * (factor ** 2)) - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - kernel.to(device=hidden_states.device), - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - return output - - -def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): - r"""Downsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the - given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the - specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its - shape is a multiple of the downsampling factor. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). 
- - Returns: - output: Tensor of the shape `[N, C, H // factor, W // factor]` - """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = pi.tensor(kernel, dtype=pi.float32) - if kernel.ndim == 1: - kernel = pi.outer(kernel, kernel) - kernel /= pi.sum(kernel) - - kernel = kernel * gain - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - kernel.to(device=hidden_states.device), - down=factor, - pad=((pad_value + 1) // 2, pad_value // 2), - ) - return output - - -def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): - up_x = up_y = up - down_x = down_y = down - pad_x0 = pad_y0 = pad[0] - pad_x1 = pad_y1 = pad[1] - - _, channel, in_h, in_w = tensor.shape - tensor = tensor.reshape(-1, in_h, in_w, 1) - - _, in_h, in_w, minor = tensor.shape - kernel_h, kernel_w = kernel.shape - - out = tensor.view(-1, in_h, 1, in_w, 1, minor) - out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) - out = out.view(-1, in_h * up_y, in_w * up_x, minor) - - out = F.pad( - out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] - ) - out = out.to(tensor.device) # Move back to mps if necessary - out = out[ - :, - max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), - max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), - :, - ] - - out = out.permute(0, 3, 1, 2) - out = out.reshape( - [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] - ) - w = pi.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) - out = F.conv2d(out, w) - out = out.reshape( - -1, - minor, - in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, - in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, - ) - out = out.permute(0, 2, 3, 1) - out = out[:, ::down_y, ::down_x, :] - - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - - return out.view(-1, channel, out_h, out_w) diff --git a/pi/models/unet/transformer_2d.py b/pi/models/unet/transformer_2d.py deleted file mode 100644 index 5376875..0000000 --- a/pi/models/unet/transformer_2d.py +++ /dev/null @@ -1,357 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Optional - -from ... import nn -from ... import pi -from ...nn import functional as F - -from .attention import BasicTransformerBlock -from .embeddings import PatchEmbed, ImagePositionalEmbeddings - - -@dataclass -class Transformer2DModelOutput: - """ - Args: - sample (`pi.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions - for the unnoised latent pixels. - """ - - sample: pi.FloatTensor - - -class Transformer2DModel(nn.Module): - """ - Transformer model for image-like data. 
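`downsample_2d` / `upfirdn2d_native` above implement FIR-filtered resampling: build a separable kernel from e.g. `(1, 3, 3, 1)`, normalise it so constant inputs keep their magnitude, pad, convolve, and subsample. A sketch of the non-convolution path for factor-2 downsampling, collapsed into a single strided conv for brevity (the pad split `(1, 1)` mirrors the `((pad_value + 1) // 2, pad_value // 2)` rule above):

```python
import torch
import torch.nn.functional as F

# Separable FIR kernel (1, 3, 3, 1): outer product, normalised to sum to 1
# so that constant inputs keep their magnitude (gain = 1).
k1d = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = torch.outer(k1d, k1d)
kernel = (kernel / kernel.sum()).view(1, 1, 4, 4)

x = torch.ones(1, 1, 8, 8)

# Blur + stride-2 subsample in one conv; the shape matches downsample_2d(factor=2).
y = F.conv2d(F.pad(x, (1, 1, 1, 1)), kernel, stride=2)
print(y.shape)        # torch.Size([1, 1, 4, 4])
print(y[0, 0, 1, 1])  # interior pixels stay ~1.0 because the kernel is normalised
```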
Takes either discrete (classes of vector embeddings) or continuous (actual - embeddings) inputs. - - When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard - transformer action. Finally, reshape to image. - - When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional - embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict - classes of unnoised image. - - Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised - image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - Pass if the input is continuous. The number of channels in the input and output. - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. - sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. - Note that this is fixed at training time as it is used for learning a number of position embeddings. See - `ImagePositionalEmbeddings`. - num_vector_embeds (`int`, *optional*): - Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. - The number of diffusion steps used during training. Note that this is fixed at training time as it is used - to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for - up to but not more than steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the TransformerBlocks' attention should contain a bias parameter. - """ - - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - patch_size: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - norm_type: str = "layer_norm", - norm_elementwise_affine: bool = True, - ): - super().__init__() - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - # 1. 
Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` - # Define whether input is continuous or discrete depending on configuration - self.is_input_continuous = (in_channels is not None) and (patch_size is None) - self.is_input_vectorized = num_vector_embeds is not None - self.is_input_patches = in_channels is not None and patch_size is not None - - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = ( - f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" - " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." - " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" - " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" - " would be very nice if you could open a Pull request for the `transformer/config.json` file" - ) - norm_type = "ada_norm" - - if self.is_input_continuous and self.is_input_vectorized: - raise ValueError( - f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" - " sure that either `in_channels` or `num_vector_embeds` is None." - ) - elif self.is_input_vectorized and self.is_input_patches: - raise ValueError( - f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" - " sure that either `num_vector_embeds` or `num_patches` is None." - ) - elif ( - not self.is_input_continuous - and not self.is_input_vectorized - and not self.is_input_patches - ): - raise ValueError( - f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" - f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." - ) - - # 2. Define input layers - if self.is_input_continuous: - self.in_channels = in_channels - - self.norm = pi.nn.GroupNorm( - num_groups=norm_num_groups, - num_channels=in_channels, - eps=1e-6, - affine=True, - ) - if use_linear_projection: - self.proj_in = nn.Linear(in_channels, inner_dim) - else: - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) - elif self.is_input_vectorized: - assert ( - sample_size is not None - ), "Transformer2DModel over discrete input must provide sample_size" - assert ( - num_vector_embeds is not None - ), "Transformer2DModel over discrete input must provide num_embed" - - self.height = sample_size - self.width = sample_size - self.num_vector_embeds = num_vector_embeds - self.num_latent_pixels = self.height * self.width - - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=num_vector_embeds, - embed_dim=inner_dim, - height=self.height, - width=self.width, - ) - elif self.is_input_patches: - assert ( - sample_size is not None - ), "Transformer2DModel over patched input must provide sample_size" - - self.height = sample_size - self.width = sample_size - - self.patch_size = patch_size - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - ) - - # 3. 
Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - ) - for d in range(num_layers) - ] - ) - - # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - if self.is_input_continuous: - # TODO: should use out_channels for continous projections - if use_linear_projection: - self.proj_out = nn.Linear(in_channels, inner_dim) - else: - self.proj_out = nn.Conv2d( - inner_dim, in_channels, kernel_size=1, stride=1, padding=0 - ) - elif self.is_input_vectorized: - self.norm_out = nn.LayerNorm(inner_dim) - self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - elif self.is_input_patches: - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) - self.proj_out_2 = nn.Linear( - inner_dim, patch_size * patch_size * self.out_channels - ) - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, - return_dict: bool = True, - ): - """ - Args: - hidden_states ( When discrete, `pi.LongTensor` of shape `(batch size, num latent pixels)`. - When continous, `pi.FloatTensor` of shape `(batch size, channel, height, width)`): Input - hidden_states - encoder_hidden_states ( `pi.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `pi.long`, *optional*): - Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. - class_labels ( `pi.LongTensor` of shape `(batch size, num classes)`, *optional*): - Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels - conditioning. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: - [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # 1. Input - if self.is_input_continuous: - batch, _, height, width = hidden_states.shape - residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( - batch, height * width, inner_dim - ) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( - batch, height * width, inner_dim - ) - hidden_states = self.proj_in(hidden_states) - elif self.is_input_vectorized: - hidden_states = self.latent_image_embedding(hidden_states) - elif self.is_input_patches: - hidden_states = self.pos_embed(hidden_states) - - # 2. 
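For continuous inputs, the forward pass above flattens the feature map from `(batch, channel, height, width)` into a `(batch, height * width, channel)` token sequence before the transformer blocks, then restores the layout afterwards so the convolutional residual can be added. The round trip in isolation (the 320-channel, 16x16 feature map is illustrative):

```python
import torch

batch, channels, height, width = 2, 320, 16, 16
x = torch.randn(batch, channels, height, width)

# BCHW -> B(HW)C before the transformer blocks ...
tokens = x.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
print(tokens.shape)  # torch.Size([2, 256, 320])

# ... and back to BCHW afterwards.
restored = tokens.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
assert torch.equal(restored, x)
```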
Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. Output - if self.is_input_continuous: - if not self.use_linear_projection: - hidden_states = ( - hidden_states.reshape(batch, height, width, inner_dim) - .permute(0, 3, 1, 2) - .contiguous() - ) - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states.reshape(batch, height, width, inner_dim) - .permute(0, 3, 1, 2) - .contiguous() - ) - - output = hidden_states + residual - elif self.is_input_vectorized: - hidden_states = self.norm_out(hidden_states) - logits = self.out(hidden_states) - # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) - logits = logits.permute(0, 2, 1) - - # log(p(x_0)) - output = F.log_softmax(logits.double(), dim=1).float() - elif self.is_input_patches: - # TODO: cleanup! - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = ( - self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - ) - hidden_states = self.proj_out_2(hidden_states) - - # unpatchify - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=( - -1, - height, - width, - self.patch_size, - self.patch_size, - self.out_channels, - ) - ) - hidden_states = pi.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=( - -1, - self.out_channels, - height * self.patch_size, - width * self.patch_size, - ) - ) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/pi/models/unet/unet_2d_blocks.py b/pi/models/unet/unet_2d_blocks.py deleted file mode 100644 index 75b51b5..0000000 --- a/pi/models/unet/unet_2d_blocks.py +++ /dev/null @@ -1,2312 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -from diffusers import Transformer2DModel - -from .attention import AttentionBlock -from .resnet import ResnetBlock2D -from ... import nn, Tensor -from ... 
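The patched-input branch above ends with an unpatchify step: per-patch predictions of shape `(N, H*W, p*p*C)` are folded back into an image via the `nhwpqc->nchpwq` einsum. A small worked example of that reshape (the 4x4 patch grid, patch size 2, and 4 output channels are illustrative):

```python
import torch

batch, grid, patch_size, out_channels = 2, 4, 2, 4
tokens = torch.randn(batch, grid * grid, patch_size * patch_size * out_channels)

# (N, H*W, p*p*C) -> (N, C, H*p, W*p), mirroring the einsum above.
x = tokens.reshape(batch, grid, grid, patch_size, patch_size, out_channels)
x = torch.einsum("nhwpqc->nchpwq", x)
image = x.reshape(batch, out_channels, grid * patch_size, grid * patch_size)
print(image.shape)  # torch.Size([2, 4, 8, 8])
```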
import pi - - -def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - resnet_groups=None, - cross_attention_dim=None, - downsample_padding=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - down_block_type = ( - down_block_type[7:] - if down_block_type.startswith("UNetRes") - else down_block_type - ) - if down_block_type == "DownBlock2D": - return DownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "ResnetDownsampleBlock2D": - return ResnetDownsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnDownBlock2D": - return AttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "CrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for CrossAttnDownBlock2D" - ) - return CrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "SimpleCrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D" - ) - return SimpleCrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "SkipDownBlock2D": - return SkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == 
"AttnSkipDownBlock2D": - return AttnSkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "DownEncoderBlock2D": - return DownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnDownEncoderBlock2D": - return AttnDownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - raise ValueError(f"{down_block_type} does not exist.") - - -def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - resnet_groups=None, - cross_attention_dim=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - up_block_type = ( - up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - ) - if up_block_type == "UpBlock2D": - return UpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "ResnetUpsampleBlock2D": - return ResnetUpsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "CrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for CrossAttnUpBlock2D" - ) - return CrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "SimpleCrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D" - ) - return 
SimpleCrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnUpBlock2D": - return AttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "SkipUpBlock2D": - return SkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnSkipUpBlock2D": - return AttnSkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "UpDecoderBlock2D": - return UpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnUpDecoderBlock2D": - return AttnUpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - raise ValueError(f"{up_block_type} does not exist.") - - -class UNetMidBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - add_attention: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - ): - super().__init__() - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - self.add_attention = add_attention - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - if self.add_attention: - attentions.append( - AttentionBlock( - in_channels, - num_head_channels=attn_num_head_channels, - 
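`get_down_block` / `get_up_block` above are plain string-to-class factories: an optional `UNetRes` prefix is stripped, the name selects a block class, and the per-block keyword arguments are forwarded. A toy sketch of that dispatch, rewritten as a table lookup with stub classes purely for illustration (the real functions use an if/elif chain and the full block constructors):

```python
class DownBlock2D:            # stand-in for the real block class
    pass


class CrossAttnDownBlock2D:   # stand-in for the real block class
    pass


_DOWN_BLOCKS = {
    "DownBlock2D": DownBlock2D,
    "CrossAttnDownBlock2D": CrossAttnDownBlock2D,
}


def get_down_block_sketch(down_block_type, **kwargs):
    # Strip the optional "UNetRes" prefix, then dispatch on the remaining name.
    name = (
        down_block_type[7:]
        if down_block_type.startswith("UNetRes")
        else down_block_type
    )
    if name not in _DOWN_BLOCKS:
        raise ValueError(f"{name} does not exist.")
    return _DOWN_BLOCKS[name](**kwargs)


print(type(get_down_block_sketch("UNetResDownBlock2D")).__name__)  # DownBlock2D
```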
rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - else: - attentions.append(None) - - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = attn(hidden_states) - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class UNetMidBlock2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=False, - upcast_attention=False, - ): - super().__init__() - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class 
UNetMidBlock2DSimpleCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - ): - super().__init__() - - self.has_cross_attention = True - - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = ( - resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - ) - - self.num_heads = in_channels // self.attn_num_head_channels - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - attentions.append( - CrossAttention( - query_dim=in_channels, - cross_attention_dim=in_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - # resnet - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class AttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - 
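# All of the mid blocks above share the same interleaving: one leading resnet,
# then alternating (attention, resnet) pairs drawn from self.resnets[1:]. A
# self-contained torch sketch of that control flow, with nn.Identity standing in
# for the real resnet/attention modules (which additionally take temb and
# encoder hidden states):
import torch
import torch.nn as nn

num_layers = 2
resnets = nn.ModuleList([nn.Identity() for _ in range(num_layers + 1)])
attentions = nn.ModuleList([nn.Identity() for _ in range(num_layers)])

def mid_forward(hidden_states):
    hidden_states = resnets[0](hidden_states)
    for attn, resnet in zip(attentions, resnets[1:]):
        hidden_states = attn(hidden_states)
        hidden_states = resnet(hidden_states)
    return hidden_states

print(mid_forward(torch.randn(1, 4, 8, 8)).shape)  # torch.Size([1, 4, 8, 8])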
norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - # TODO(Patrick, William) - attention mask is not used - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - 
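# The checkpointed branch below wraps each sub-module in create_custom_forward so
# extra positional arguments (and a return_dict flag) can be threaded through the
# checkpoint call, which recomputes activations during backward instead of
# storing them. A minimal sketch of the same pattern using torch's checkpoint
# API (the deleted file itself goes through the pi namespace):
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Conv2d(4, 4, kernel_size=3, padding=1)

def _create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

x = torch.randn(1, 4, 8, 8, requires_grad=True)
y = checkpoint(_create_custom_forward(block), x)  # activations rebuilt on backward
y.sum().backward()
print(x.grad.shape)  # torch.Size([1, 4, 8, 8])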
hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class DownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class DownEncoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - 
Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - -class AttnDownEncoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states): - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - -class AttnSkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=np.sqrt(2.0), - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - 
groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList( - [FirDownsample2D(out_channels, out_channels=out_channels)] - ) - self.skip_conv = nn.Conv2d( - 3, out_channels, kernel_size=(1, 1), stride=(1, 1) - ) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward(self, hidden_states, temb=None, skip_sample=None): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states += (hidden_states,) - - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - - hidden_states = self.skip_conv(skip_sample) + hidden_states - - output_states += (hidden_states,) - - return hidden_states, output_states, skip_sample - - -class SkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - output_scale_factor=np.sqrt(2.0), - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList( - [FirDownsample2D(out_channels, out_channels=out_channels)] - ) - self.skip_conv = nn.Conv2d( - 3, out_channels, kernel_size=(1, 1), stride=(1, 1) - ) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward(self, hidden_states, temb=None, skip_sample=None): - output_states = () - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - - hidden_states = self.skip_conv(skip_sample) + hidden_states - - output_states += (hidden_states,) - - return hidden_states, output_states, skip_sample - - -class ResnetDownsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = 
"default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - down=True, - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class SimpleCrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_downsample=True, - ): - super().__init__() - - self.has_cross_attention = True - - resnets = [] - attentions = [] - - self.attn_num_head_channels = attn_num_head_channels - self.num_heads = out_channels // self.attn_num_head_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - CrossAttention( - query_dim=out_channels, - cross_attention_dim=out_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - 
dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - down=True, - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - output_states = () - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - - for resnet, attn in zip(self.resnets, self.attentions): - # resnet - hidden_states = resnet(hidden_states, temb) - - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class AttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class CrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - 
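# Up blocks above consume the encoder skips in reverse order: the last entry of
# res_hidden_states_tuple is popped and concatenated on the channel dim before
# each resnet, and res_skip_channels / resnet_in_channels determine the width
# each resnet must accept. A standalone torch sketch with hypothetical sizes:
import torch

def up_block_input_widths(in_channels, out_channels, prev_output_channel, num_layers):
    widths = []
    for i in range(num_layers):
        res_skip_channels = in_channels if i == num_layers - 1 else out_channels
        resnet_in_channels = prev_output_channel if i == 0 else out_channels
        widths.append(resnet_in_channels + res_skip_channels)
    return widths

print(up_block_input_widths(32, 64, 64, num_layers=3))  # [128, 128, 96]

hidden_states = torch.randn(1, 64, 8, 8)
res_hidden_states_tuple = (torch.randn(1, 32, 8, 8), torch.randn(1, 64, 8, 8))
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
print(hidden_states.shape)  # torch.Size([1, 128, 8, 8])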
output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, - ): - # TODO(Patrick, William) - attention mask is not used - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class UpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 
32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None - ): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class UpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class AttnUpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - input_channels = 
in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] - ) - else: - self.upsamplers = None - - def forward(self, hidden_states): - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class AttnSkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=np.sqrt(2.0), - upsample_padding=1, - add_upsample=True, - ): - super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(resnet_in_channels + res_skip_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d( - out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) - ) - self.skip_norm = pi.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), - num_channels=out_channels, - eps=resnet_eps, - affine=True, - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None - ): - for resnet in self.resnets: - # pop res 
hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - hidden_states = self.attentions[0](hidden_states) - - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - - skip_sample = skip_sample + skip_sample_states - - hidden_states = self.resnet_up(hidden_states, temb) - - return hidden_states, skip_sample - - -class SkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - output_scale_factor=np.sqrt(2.0), - add_upsample=True, - upsample_padding=1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min((resnet_in_channels + res_skip_channels) // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d( - out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) - ) - self.skip_norm = pi.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), - num_channels=out_channels, - eps=resnet_eps, - affine=True, - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None - ): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - - skip_sample = skip_sample + skip_sample_states - - hidden_states = self.resnet_up(hidden_states, temb) - - return 
hidden_states, skip_sample - - -class ResnetUpsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - up=True, - ) - ] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None - ): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = pi.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - else: - hidden_states = resnet(hidden_states, temb) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) - - return hidden_states - - -class SimpleCrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - self.num_heads = out_channels // self.attn_num_head_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - 
output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - CrossAttention( - query_dim=out_channels, - cross_attention_dim=out_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - up=True, - ) - ] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - cross_attention_kwargs = ( - cross_attention_kwargs if cross_attention_kwargs is not None else {} - ) - for resnet, attn in zip(self.resnets, self.attentions): - # resnet - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = pi.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) - - return hidden_states diff --git a/pi/models/unet/unet_2d_condition.py b/pi/models/unet/unet_2d_condition.py deleted file mode 100644 index e5df4d1..0000000 --- a/pi/models/unet/unet_2d_condition.py +++ /dev/null @@ -1,517 +0,0 @@ -import logging -import math -from dataclasses import dataclass -from typing import Optional, Tuple, Union, Dict, List, Any - -from diffusers.models.unet_2d_blocks import get_down_block - -from .cross_attention import AttnProcessor -from .embeddings import Timesteps, TimestepEmbedding -from .unet_2d_blocks import ( - UNetMidBlock2DCrossAttn, - UNetMidBlock2DSimpleCrossAttn, - get_up_block, - CrossAttnDownBlock2D, - DownBlock2D, - CrossAttnUpBlock2D, - UpBlock2D, -) -from ... import nn, Tensor -from ... import pi - - -logger = logging.getLogger(__name__) - - -@dataclass -class UNet2DConditionOutput: - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. 
- """ - - sample: pi.FloatTensor - - -class UNet2DConditionModel: - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: str = "UNetMidBlock2DCrossAttn", - up_block_types: Tuple[str] = ( - "UpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - ), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: int = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - attention_head_dim: Union[int, Tuple[int]] = 8, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - ): - super().__init__() - - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 - - # input - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1) - ) - - # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - else: - self.class_embedding = None - - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - - # mid - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - 
output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": - self.mid_block = UNetMidBlock2DSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) - only_cross_attention = list(reversed(only_cross_attention)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[ - min(i + 1, len(block_out_channels) - 1) - ] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=reversed_attention_head_dim[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps - ) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=3, padding=1 - ) - - @property - def attn_processors(self) -> Dict[str, AttnProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. 
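# How __init__ above pairs channel widths for the up path: block_out_channels is
# reversed, the previous step's output_channel becomes prev_output_channel, and
# input_channel is clamped to the deepest level; every block except the last
# gets an upsampler. A standalone reproduction using the default config values:
block_out_channels = (320, 640, 1280, 1280)
up_block_types = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")

reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
    is_final_block = i == len(block_out_channels) - 1
    prev_output_channel = output_channel
    output_channel = reversed_block_out_channels[i]
    input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
    print(up_block_type, input_channel, output_channel, prev_output_channel, not is_final_block)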
- """ - # set recursively - processors = {} - - def fn_recursive_add_processors( - name: str, module: pi.nn.Module, processors: Dict[str, AttnProcessor] - ): - if hasattr(module, "set_processor"): - processors[f"{name}.processor"] = module.processor - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - def set_attn_processor( - self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]] - ): - r""" - Parameters: - `processor (`dict` of `AttnProcessor` or `AttnProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `CrossAttention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: pi.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. 
- """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_slicable_dims(module: pi.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_slicable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_slicable_dims(module) - - num_slicable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_slicable_layers * [1] - - slice_size = ( - num_slicable_layers * [slice_size] - if not isinstance(slice_size, list) - else slice_size - ) - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice( - module: pi.nn.Module, slice_size: List[int] - ): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance( - module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D) - ): - module.gradient_checkpointing = value - - def forward( - self, - sample: pi.FloatTensor, - timestep: Union[pi.Tensor, float, int], - encoder_hidden_states: pi.Tensor, - class_labels: Optional[pi.Tensor] = None, - attention_mask: Optional[pi.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - r""" - Args: - sample (`pi.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`pi.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`pi.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. 
- default_overall_up_factor = 2 ** self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not pi.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = pi.float32 if is_mps else pi.float64 - else: - dtype = pi.int32 if is_mps else pi.int64 - timesteps = pi.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError( - "class_labels should be provided when num_class_embeds > 0" - ) - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if ( - hasattr(downsample_block, "has_cross_attention") - and downsample_block.has_cross_attention - ): - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - - # 5. 
up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[ - : -len(upsample_block.resnets) - ] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if ( - hasattr(upsample_block, "has_cross_attention") - and upsample_block.has_cross_attention - ): - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - ) - else: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - upsample_size=upsample_size, - ) - # 6. post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) diff --git a/pyproject.toml b/pyproject.toml index 37876fc..314893f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,12 @@ requires = ["setuptools>=42", "ninja"] build-backend = "setuptools.build_meta" +[tool.pytest.ini_options] +log_cli = false +log_cli_level = "DEBUG" +log_cli_format = "[%(filename)s:%(funcName)s:%(lineno)d] %(message)s" +log_cli_date_format = "%Y-%m-%d %H:%M:%S" + [tool.cibuildwheel] before-build = "pip install -r build-requirements.txt -r requirements.txt -v" # HOLY FUCK
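For reference when reviewing the `set_attention_slice` plumbing in the patch above: the method normalizes a single `"auto"` / `"max"` / integer setting into one slice size per sliceable attention layer before pushing the values down the module tree. The standalone sketch below mirrors that normalization outside the model so it can be inspected in isolation; the function name `resolve_slice_sizes` and the example head dims are illustrative assumptions, not part of the patch.

```python
from typing import List, Union


def resolve_slice_sizes(
    sliceable_head_dims: List[int], slice_size: Union[str, int, List[int]]
) -> List[int]:
    # "auto": halve each attention head dim, so attention runs in two steps.
    if slice_size == "auto":
        resolved = [dim // 2 for dim in sliceable_head_dims]
    # "max": one slice at a time, the smallest memory footprint (and slowest).
    elif slice_size == "max":
        resolved = len(sliceable_head_dims) * [1]
    # A single int is broadcast to every sliceable layer.
    elif isinstance(slice_size, int):
        resolved = len(sliceable_head_dims) * [slice_size]
    else:
        resolved = list(slice_size)

    if len(resolved) != len(sliceable_head_dims):
        raise ValueError(
            f"got {len(resolved)} slice sizes for {len(sliceable_head_dims)} layers"
        )
    for size, dim in zip(resolved, sliceable_head_dims):
        if size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
    return resolved


# e.g. head dims collected from a model with three sliceable attention layers
print(resolve_slice_sizes([8, 8, 16], "auto"))  # [4, 4, 8]
print(resolve_slice_sizes([8, 8, 16], "max"))   # [1, 1, 1]
print(resolve_slice_sizes([8, 8, 16], 4))       # [4, 4, 4]
```

Keeping the resolution step pure like this makes the `"auto"`/`"max"` trade-offs easy to check without instantiating the UNet.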
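The comment block at the top of `forward` explains why an explicit `upsample_size` is only forwarded when the input's spatial dimensions are not a multiple of `2 ** num_upsamplers`. A minimal sketch of that check, under an assumed `num_upsamplers` value and with a hypothetical helper name:

```python
import torch


def needs_forced_upsample_size(sample: torch.Tensor, num_upsamplers: int) -> bool:
    # Each up block except the final one doubles the spatial size, so inputs
    # whose H/W are multiples of 2 ** num_upsamplers round-trip exactly and
    # need no forced interpolation output size.
    default_overall_up_factor = 2 ** num_upsamplers
    return any(s % default_overall_up_factor != 0 for s in sample.shape[-2:])


# e.g. with a single upsampler, a 32x32 latent is fine while 33x32 is not
print(needs_forced_upsample_size(torch.zeros(4, 4, 32, 32), 1))  # False
print(needs_forced_upsample_size(torch.zeros(4, 4, 33, 32), 1))  # True
```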
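Likewise, the timestep handling in `forward` accepts a Python scalar, a 0-d tensor, or a 1-d tensor and always ends up broadcast to the batch dimension. The sketch below reproduces that normalization with plain `torch` calls, dropping the `mps`-specific dtype branch for brevity; `normalize_timesteps` is an assumed name, not an API of the patch.

```python
import torch


def normalize_timesteps(timestep, sample: torch.Tensor) -> torch.Tensor:
    # Scalars become 1-element tensors, 0-d tensors gain a batch axis, then
    # everything is expanded to the batch size (ONNX/Core ML friendly).
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        dtype = torch.float64 if isinstance(timestep, float) else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)
    return timesteps.expand(sample.shape[0])


sample = torch.zeros(4, 4, 32, 32)
print(normalize_timesteps(10, sample).shape)                # torch.Size([4])
print(normalize_timesteps(torch.tensor(10), sample).shape)  # torch.Size([4])
```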