Merge pull request #110 from williamberman/will/inpainting-fixes
add fixes to inpainting pipeline
williamberman authored Sep 1, 2023
2 parents ec965ed + 0f4171c commit 0ba7d01
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion muse/pipeline_muse.py
@@ -346,6 +346,9 @@ def __call__(
generator: Optional[torch.Generator] = None,
use_fp16: bool = False,
image_size: int = 256,
orig_size=(256, 256),
crop_coords=(0, 0),
aesthetic_score=6.0,
):
from torchvision import transforms

@@ -366,7 +369,9 @@ def __call__(
pixel_values = encode_transform(image).unsqueeze(0).to(self.device)
_, image_tokens = self.vae.encode(pixel_values)
mask_token_id = self.transformer.config.mask_token_id

image_tokens[mask[None]] = mask_token_id

image_tokens = image_tokens.repeat(num_images_per_prompt, 1)
if class_ids is not None:
if isinstance(class_ids, int):
@@ -388,7 +393,13 @@ def __call__(
max_length=self.tokenizer.model_max_length,
).input_ids # TODO: remove hardcode
input_ids = input_ids.to(self.device)
encoder_hidden_states = self.text_encoder(input_ids).last_hidden_state

if self.transformer.config.add_cond_embeds:
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
pooled_embeds, encoder_hidden_states = outputs.text_embeds, outputs.hidden_states[-2]
else:
encoder_hidden_states = self.text_encoder(input_ids).last_hidden_state
pooled_embeds = None

if negative_text is not None:
if isinstance(negative_text, str):
@@ -417,10 +428,27 @@ def __call__(
bs_embed * num_images_per_prompt, seq_len, -1
)

empty_input = self.tokenizer("", padding="max_length", return_tensors="pt").input_ids.to(
self.text_encoder.device
)
outputs = self.text_encoder(empty_input, output_hidden_states=True)
empty_embeds = outputs.hidden_states[-2]
empty_cond_embeds = outputs[0]

model_inputs = {
"encoder_hidden_states": encoder_hidden_states,
"negative_embeds": negative_encoder_hidden_states,
"empty_embeds": empty_embeds,
"empty_cond_embeds": empty_cond_embeds,
"cond_embeds": pooled_embeds,
}

if self.transformer.config.add_micro_cond_embeds:
micro_conds = list(orig_size) + list(crop_coords) + [aesthetic_score]
micro_conds = torch.tensor(micro_conds, device=self.device, dtype=encoder_hidden_states.dtype)
micro_conds = micro_conds.unsqueeze(0)
model_inputs["micro_conds"] = micro_conds

generate = self.transformer.generate2
with torch.autocast("cuda", enabled=use_fp16):
generated_tokens = generate(
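For context, a minimal usage sketch of the call path this commit changes, written against the new keyword arguments. The class name PipelineMuseInpainting, the from_pretrained checkpoint path, what the call returns, the text prompt argument name, and the 16x16 token-grid mask layout are assumptions for illustration; only orig_size, crop_coords, aesthetic_score, image_size, use_fp16, generator, image, and mask appear in the diff above.

# Hypothetical usage sketch -- class name, checkpoint path, prompt argument, and
# mask layout are assumptions; the micro-conditioning kwargs mirror this commit.
import torch
from PIL import Image

from muse import PipelineMuseInpainting  # assumed import path

pipe = PipelineMuseInpainting.from_pretrained("path/to/muse-checkpoint")  # assumed checkpoint

image = Image.open("input.png").convert("RGB")

# Flat boolean mask over the VQ token grid (True = token to repaint); a 16x16
# grid for a 256px image is an assumption, inferred from
# `image_tokens[mask[None]] = mask_token_id` in the diff.
mask = torch.zeros(16, 16, dtype=torch.bool)
mask[4:12, 4:12] = True
mask = mask.reshape(-1)

images = pipe(
    image=image,
    mask=mask,
    text="a photo of a cat sitting on a bench",  # prompt argument name assumed
    image_size=256,
    orig_size=(256, 256),   # new micro-conditioning input added in this commit
    crop_coords=(0, 0),     # new micro-conditioning input added in this commit
    aesthetic_score=6.0,    # new micro-conditioning input added in this commit
    use_fp16=True,
    generator=torch.Generator().manual_seed(0),
)

Per the diff, the three new arguments are only packed into the micro_conds tensor when transformer.config.add_micro_cond_embeds is set; otherwise they are not passed to the transformer.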
