Commit

fix style
jazcollins committed Oct 2, 2023
1 parent 45c0cd7 commit 2346076
Showing 4 changed files with 28 additions and 24 deletions.
13 changes: 8 additions & 5 deletions diffusion/callbacks/log_diffusion_images.py
@@ -63,18 +63,21 @@ def eval_batch_end(self, state: State, logger: Logger):
         model = state.model
 
         if self.tokenized_prompts is None:
-            tokenized_prompts = [
+            self.tokenized_prompts = [
                 model.tokenizer(p, padding='max_length', truncation=True,
                                 return_tensors='pt')['input_ids']  # type: ignore
                 for p in self.prompts
             ]
 
             if model.sdxl:
-                tokenized_prompts_1 = torch.cat([tp[0] for tp in tokenized_prompts]).to(state.batch[self.text_key].device)
-                tokenized_prompts_2 = torch.cat([tp[1] for tp in tokenized_prompts]).to(state.batch[self.text_key].device)
+                tokenized_prompts_1 = torch.cat([tp[0] for tp in self.tokenized_prompts
+                                                ]).to(state.batch[self.text_key].device)
+                tokenized_prompts_2 = torch.cat([tp[1] for tp in self.tokenized_prompts
+                                                ]).to(state.batch[self.text_key].device)
                 self.tokenized_prompts = [tokenized_prompts_1, tokenized_prompts_2]
             else:
-                self.tokenized_prompts = torch.cat(tokenized_prompts)
-                self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device)  # type: ignore
+                self.tokenized_prompts = torch.cat(self.tokenized_prompts)  # type: ignore
+                self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device)
 
         # Generate images
         with get_precision_context(state.precision):
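For context, this hunk stores the freshly tokenized prompts directly on the callback instance and then concatenates them and moves them to the batch's device on first use. A minimal, hypothetical sketch of that cache-on-first-use pattern (dummy tensors stand in for real tokenizer output; PromptCacheSketch is not the real callback):

import torch

class PromptCacheSketch:
    """Illustrates cache-on-first-use; not the actual logging callback."""

    def __init__(self, prompts):
        self.prompts = prompts
        self.tokenized_prompts = None  # populated on the first eval batch

    def get_prompts(self, device):
        if self.tokenized_prompts is None:
            # Tokenize once and keep the result on the instance,
            # so later eval batches reuse it instead of re-tokenizing.
            dummy_ids = [torch.randint(0, 100, (1, 77)) for _ in self.prompts]
            self.tokenized_prompts = torch.cat(dummy_ids).to(device)
        return self.tokenized_prompts

cache = PromptCacheSketch(['a cat', 'a dog'])
ids = cache.get_prompts(torch.device('cpu'))  # shape (2, 77)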
13 changes: 6 additions & 7 deletions diffusion/datasets/image_caption.py
@@ -129,13 +129,12 @@ def __getitem__(self, index):
         if isinstance(caption, List) and self.caption_selection == 'random':
             caption = random.sample(caption, k=1)[0]
 
-        max_length = None if self.sdxl else self.tokenizer.model_max_length
-        tokenized_caption = self.tokenizer(
-            caption,
-            padding='max_length',
-            max_length=max_length,
-            truncation=True,
-            return_tensors='pt')['input_ids']
+        max_length = None if self.sdxl else self.tokenizer.model_max_length  # type: ignore
+        tokenized_caption = self.tokenizer(caption,
+                                           padding='max_length',
+                                           max_length=max_length,
+                                           truncation=True,
+                                           return_tensors='pt')['input_ids']
         if self.sdxl:
             tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenized_caption]
             tokenized_caption = torch.stack(tokenized_caption)
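In the SDXL branch above, the tokenizer wrapper returns one input_ids tensor per sub-tokenizer; squeezing then stacking yields a single (2, max_length) tensor per caption. A small sketch of the shapes involved, assuming the usual (1, 77) per-tokenizer output:

import torch

# Dummy input_ids from the two tokenizers, each of shape (1, 77).
tokenized_caption = [torch.zeros(1, 77, dtype=torch.long),
                     torch.zeros(1, 77, dtype=torch.long)]
tokenized_caption = [cap.squeeze() for cap in tokenized_caption]  # each -> (77,)
tokenized_caption = torch.stack(tokenized_caption)                # -> (2, 77)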
2 changes: 1 addition & 1 deletion diffusion/datasets/laion/transforms.py
@@ -4,8 +4,8 @@
 """Transforms for the training and eval dataset."""
 
 import torchvision.transforms as transforms
-from torchvision.transforms.functional import crop
 from torchvision.transforms import RandomCrop
+from torchvision.transforms.functional import crop
 
 
 class LargestCenterSquare:
24 changes: 13 additions & 11 deletions diffusion/models/models.py
@@ -427,17 +427,19 @@ def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'):
         self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2')
 
     def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
-        tokenized_output = self.tokenizer(prompt,
-                                          padding=padding,
-                                          max_length=self.tokenizer.model_max_length if max_length is None else max_length,
-                                          truncation=truncation,
-                                          return_tensors=return_tensors)
-        tokenized_output_2 = self.tokenizer_2(prompt,
-                                              padding=padding,
-                                              max_length=self.tokenizer_2.model_max_length if max_length is None else max_length,
-                                              truncation=truncation,
-                                              return_tensors=return_tensors)
-
+        tokenized_output = self.tokenizer(
+            prompt,
+            padding=padding,
+            max_length=self.tokenizer.model_max_length if max_length is None else max_length,
+            truncation=truncation,
+            return_tensors=return_tensors)
+        tokenized_output_2 = self.tokenizer_2(
+            prompt,
+            padding=padding,
+            max_length=self.tokenizer_2.model_max_length if max_length is None else max_length,
+            truncation=truncation,
+            return_tensors=return_tensors)
+
         # Add second tokenizer output to first tokenizer
         for key in tokenized_output.keys():
             tokenized_output[key] = [tokenized_output[key], tokenized_output_2[key]]
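The reformatted __call__ above tokenizes the prompt with both CLIP tokenizers and packs the two results into a per-key list. A hedged usage sketch, assuming this wrapper is exposed as SDXLTokenizer in diffusion.models.models and that __call__ returns the merged tokenized_output:

# Assumed class name and import path; adjust to the repo's actual exports.
from diffusion.models.models import SDXLTokenizer

tokenizer = SDXLTokenizer('stabilityai/stable-diffusion-xl-base-1.0')
out = tokenizer('an astronaut riding a horse', padding='max_length',
                truncation=True, return_tensors='pt')
ids_1, ids_2 = out['input_ids']  # per-key list: [tokenizer output, tokenizer_2 output]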
