Skip to content

Commit

Permalink
Equally spaced timesteps during eval
Browse files Browse the repository at this point in the history
  • Loading branch information
corystephenson-db committed Jul 25, 2024
1 parent f62f647 commit 0df0095
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions diffusion/models/t2i_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
import torch.nn.functional as F
from composer.models import ComposerModel
from composer.utils import dist
from torchmetrics import MeanSquaredError
from tqdm.auto import tqdm

Expand Down Expand Up @@ -279,11 +280,20 @@ def embed_tokenized_prompts(

def diffusion_forward_process(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Diffusion forward process using a rectified flow."""
# First, sample timesteps according to a logit-normal distribution
u = torch.randn(inputs.shape[0], device=inputs.device, generator=self.rng_generator, dtype=inputs.dtype)
u = self.timestep_mean + self.timestep_std * u
timesteps = torch.sigmoid(u).view(-1, 1, 1)
timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps)
if not self.model.training:
# Sample equally spaced timesteps across all devices
global_batch_size = inputs.shape[0] * dist.get_world_size()
global_timesteps = torch.linspace(0, 1, global_batch_size)
# Get this device's subset of all the timesteps
idx_offset = dist.get_global_rank() * inputs.shape[0]
timesteps = global_timesteps[idx_offset:idx_offset + inputs.shape[0]].to(inputs.device)
timesteps = timesteps.view(-1, 1, 1)
else:
# Sample timesteps according to a logit-normal distribution
u = torch.randn(inputs.shape[0], device=inputs.device, generator=self.rng_generator, dtype=inputs.dtype)
u = self.timestep_mean + self.timestep_std * u
timesteps = torch.sigmoid(u).view(-1, 1, 1)
timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps)
# Then, add the noise to the latents according to the recitified flow
noise = torch.randn(*inputs.shape, device=inputs.device, generator=self.rng_generator)
noised_inputs = (1 - timesteps) * inputs + timesteps * noise
Expand Down

0 comments on commit 0df0095

Please sign in to comment.