Skip to content

Commit

Permalink
Fix incorrect dimension order in image prompts (#94)
Browse files Browse the repository at this point in the history
* Fix incorrect dimension order in image prompts

This wasn't noticed before because the images are square.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
neverix and pre-commit-ci[bot] authored Jan 10, 2022
1 parent 7ef963d commit 7cce1f6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions rudalle/image_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def _preprocess_img(self, pil_img):

def _get_image_prompts(self, img, borders, vae, crop_first):
if crop_first:
bs, _, img_w, img_h = img.shape
bs, _, img_h, img_w = img.shape
vqg_img_w, vqg_img_h = img_w // 8, img_h // 8
vqg_img = torch.zeros((bs, vqg_img_w, vqg_img_h), dtype=torch.int32, device=img.device)
vqg_img = torch.zeros((bs, vqg_img_h, vqg_img_w), dtype=torch.int32, device=img.device)
if borders['down'] != 0:
down_border = borders['down'] * 8
_, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :])
Expand All @@ -49,8 +49,8 @@ def _get_image_prompts(self, img, borders, vae, crop_first):
else:
_, _, [_, _, vqg_img] = vae.model.encode(img)

bs, vqg_img_w, vqg_img_h = vqg_img.shape
mask = torch.zeros(vqg_img_w, vqg_img_h)
bs, vqg_img_h, vqg_img_w = vqg_img.shape
mask = torch.zeros(vqg_img_h, vqg_img_w)
if borders['up'] != 0:
mask[:borders['up'], :] = 1.
if borders['down'] != 0:
Expand Down

0 comments on commit 7cce1f6

Please sign in to comment.