From 85af31f2b7b33360f83bfe040a09ae8c6e87d928 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 10 Aug 2023 14:43:48 +0200 Subject: [PATCH] fix micro cond gen in training --- training/train_muse.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/training/train_muse.py b/training/train_muse.py index d5e6ef2e..132ef720 100644 --- a/training/train_muse.py +++ b/training/train_muse.py @@ -940,8 +940,9 @@ def generate_images( clip_embeds = None if config.model.transformer.get("add_micro_cond_embeds", False): + resolution = config.dataset.preprocessing.resolution micro_conds = torch.tensor( - [256, 256, 0, 0, 6], device=encoder_hidden_states.device, dtype=encoder_hidden_states.dtype + [resolution, resolution, 0, 0, 6], device=encoder_hidden_states.device, dtype=encoder_hidden_states.dtype ) micro_conds = micro_conds.unsqueeze(0).repeat(encoder_hidden_states.shape[0], 1)