diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py
index f114a3bd..2b5a7c50 100644
--- a/diffusion/callbacks/log_diffusion_images.py
+++ b/diffusion/callbacks/log_diffusion_images.py
@@ -63,18 +63,21 @@ def eval_batch_end(self, state: State, logger: Logger):
         model = state.model
 
         if self.tokenized_prompts is None:
-            tokenized_prompts = [
+            self.tokenized_prompts = [
                 model.tokenizer(p, padding='max_length', truncation=True,
                                 return_tensors='pt')['input_ids']  # type: ignore
                 for p in self.prompts
             ]
+
             if model.sdxl:
-                tokenized_prompts_1 = torch.cat([tp[0] for tp in tokenized_prompts]).to(state.batch[self.text_key].device)
-                tokenized_prompts_2 = torch.cat([tp[1] for tp in tokenized_prompts]).to(state.batch[self.text_key].device)
+                tokenized_prompts_1 = torch.cat([tp[0] for tp in self.tokenized_prompts
+                                                ]).to(state.batch[self.text_key].device)
+                tokenized_prompts_2 = torch.cat([tp[1] for tp in self.tokenized_prompts
+                                                ]).to(state.batch[self.text_key].device)
                 self.tokenized_prompts = [tokenized_prompts_1, tokenized_prompts_2]
             else:
-                self.tokenized_prompts = torch.cat(tokenized_prompts)
-                self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device)  # type: ignore
+                self.tokenized_prompts = torch.cat(self.tokenized_prompts)  # type: ignore
+                self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device)
 
         # Generate images
         with get_precision_context(state.precision):
diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py
index 25136818..24cec6dd 100644
--- a/diffusion/datasets/image_caption.py
+++ b/diffusion/datasets/image_caption.py
@@ -129,13 +129,12 @@ def __getitem__(self, index):
         if isinstance(caption, List) and self.caption_selection == 'random':
             caption = random.sample(caption, k=1)[0]
 
-        max_length = None if self.sdxl else self.tokenizer.model_max_length
-        tokenized_caption = self.tokenizer(
-            caption,
-            padding='max_length',
-            max_length=max_length,
-            truncation=True,
-            return_tensors='pt')['input_ids']
+        max_length = None if self.sdxl else self.tokenizer.model_max_length  # type: ignore
+        tokenized_caption = self.tokenizer(caption,
+                                           padding='max_length',
+                                           max_length=max_length,
+                                           truncation=True,
+                                           return_tensors='pt')['input_ids']
         if self.sdxl:
             tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenized_caption]
             tokenized_caption = torch.stack(tokenized_caption)
diff --git a/diffusion/datasets/laion/transforms.py b/diffusion/datasets/laion/transforms.py
index c0ce01b8..86ed78ad 100644
--- a/diffusion/datasets/laion/transforms.py
+++ b/diffusion/datasets/laion/transforms.py
@@ -4,8 +4,8 @@
 """Transforms for the training and eval dataset."""
 
 import torchvision.transforms as transforms
-from torchvision.transforms.functional import crop
 from torchvision.transforms import RandomCrop
+from torchvision.transforms.functional import crop
 
 
 class LargestCenterSquare:
diff --git a/diffusion/models/models.py b/diffusion/models/models.py
index 0b0a03e9..38e07c4c 100644
--- a/diffusion/models/models.py
+++ b/diffusion/models/models.py
@@ -427,17 +427,19 @@ def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'):
         self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2')
 
     def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
-        tokenized_output = self.tokenizer(prompt,
-                                          padding=padding,
-                                          max_length=self.tokenizer.model_max_length if max_length is None else max_length,
-                                          truncation=truncation,
-                                          return_tensors=return_tensors)
-        tokenized_output_2 = self.tokenizer_2(prompt,
-                                              padding=padding,
-                                              max_length=self.tokenizer_2.model_max_length if max_length is None else max_length,
-                                              truncation=truncation,
-                                              return_tensors=return_tensors)
-
+        tokenized_output = self.tokenizer(
+            prompt,
+            padding=padding,
+            max_length=self.tokenizer.model_max_length if max_length is None else max_length,
+            truncation=truncation,
+            return_tensors=return_tensors)
+        tokenized_output_2 = self.tokenizer_2(
+            prompt,
+            padding=padding,
+            max_length=self.tokenizer_2.model_max_length if max_length is None else max_length,
+            truncation=truncation,
+            return_tensors=return_tensors)
+
         # Add second tokenizer output to first tokenizer
         for key in tokenized_output.keys():
             tokenized_output[key] = [tokenized_output[key], tokenized_output_2[key]]
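
A minimal usage sketch of the dual-tokenizer wrapper reformatted in the models.py hunk above. Assumptions: the wrapper class (whose __init__ and __call__ appear in the hunk) is named SDXLTokenizer and is importable from diffusion.models.models; the prompt string is illustrative. As the final for-loop shows, each key of the returned output maps to a two-element list, one entry per CLIP tokenizer, which is why the callback indexes tp[0] and tp[1] in the SDXL branch.

    # Hypothetical usage; class name and import path are assumptions, not shown in this diff.
    from diffusion.models.models import SDXLTokenizer

    tokenizer = SDXLTokenizer('stabilityai/stable-diffusion-xl-base-1.0')
    out = tokenizer('a photo of a corgi', padding='max_length', truncation=True, return_tensors='pt')
    # out['input_ids'] is [tensor_from_tokenizer, tensor_from_tokenizer_2]
    input_ids_1, input_ids_2 = out['input_ids']  # one tensor per text encoder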