Skip to content

Commit

Permalink
log_diffusion_images.py fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jazcollins committed Oct 2, 2023
1 parent 2346076 commit 8d8e2ae
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions diffusion/callbacks/log_diffusion_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,18 @@ def eval_batch_end(self, state: State, logger: Logger):
return_tensors='pt')['input_ids'] # type: ignore
for p in self.prompts
]

if model.sdxl:
self.tokenized_prompts = [
torch.cat([tp[0] for tp in self.tokenized_prompts]),
torch.cat([tp[1] for tp in self.tokenized_prompts])
]
else:
self.tokenized_prompts = torch.cat(self.tokenized_prompts) # type: ignore
if model.sdxl:
tokenized_prompts_1 = torch.cat([tp[0] for tp in self.tokenized_prompts
]).to(state.batch[self.text_key].device)
tokenized_prompts_2 = torch.cat([tp[1] for tp in self.tokenized_prompts
]).to(state.batch[self.text_key].device)
self.tokenized_prompts = [tokenized_prompts_1, tokenized_prompts_2]
self.tokenized_prompts[0] = self.tokenized_prompts[0].to(state.batch[self.text_key].device)
self.tokenized_prompts[1] = self.tokenized_prompts[1].to(state.batch[self.text_key].device)
else:
self.tokenized_prompts = torch.cat(self.tokenized_prompts) # type: ignore
self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device)
self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # type: ignore

# Generate images
with get_precision_context(state.precision):
Expand Down

0 comments on commit 8d8e2ae

Please sign in to comment.