Skip to content

Commit

Permalink
Update diffusion/models/stable_diffusion.py
Browse files Browse the repository at this point in the history
Co-authored-by: Landan Seguin <[email protected]>
  • Loading branch information
jazcollins and Landanjs authored Oct 2, 2023
1 parent dd35c77 commit 45c0cd7
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions diffusion/models/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,8 @@ def generate(
crop_params = [0., 0.]
if not size_params:
size_params = [width, height]
cond_original_size = torch.tensor([[width, height]]).repeat(pooled_embeddings.shape[0],
1).to(device).float()
cond_crops_coords_top_left = torch.tensor([crop_params]).repeat(pooled_embeddings.shape[0],
1).to(device).float()
cond_target_size = torch.tensor([size_params]).repeat(pooled_embeddings.shape[0], 1).to(device).float()
add_time_ids = torch.cat([cond_original_size, cond_crops_coords_top_left, cond_target_size], dim=1).float()
add_time_ids = torch.tensor([[width, height, *crop_params, *size_params]], dtype=torch.float, device=device)
add_time_ids = add_time_ids.repeat(pooled_embeddings.shape[0], 1)
add_text_embeds = pooled_embeddings

added_cond_kwargs = {'text_embeds': add_text_embeds, 'time_ids': add_time_ids}
Expand Down

0 comments on commit 45c0cd7

Please sign in to comment.