From c385ea227aeefd7929d543c7e830e479aff902ed Mon Sep 17 00:00:00 2001 From: rizavelioglu Date: Wed, 8 Jan 2025 19:27:32 +0100 Subject: [PATCH] fix: incomplete inference When the number of images to run inference for is not divisible by the batch size, the inference for the last batch failed. This is because of how PNDMScheduler is implemented. The fix is achieved by calling the scheduler's method for each batch, which resets the scheduler state before each batch. --- tryoffdiff/modeling/predict.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tryoffdiff/modeling/predict.py b/tryoffdiff/modeling/predict.py index 37d7a3b..90f5b0a 100644 --- a/tryoffdiff/modeling/predict.py +++ b/tryoffdiff/modeling/predict.py @@ -35,7 +35,6 @@ def inference_tryoffdiff(config: InferenceConfig): # Set up scheduler scheduler = PNDMScheduler.from_pretrained(config.scheduler_dir) - scheduler.set_timesteps(config.num_inference_steps) # Prepare dataloader val_set = VitonVAESigLIPDataset(root_dir=config.val_img_dir, inference=True) @@ -61,6 +60,8 @@ def inference_tryoffdiff(config: InferenceConfig): for cond, img_name in tqdm(val_loader, desc="Processing batches"): cond = cond.to(config.device) batch_size = cond.size(0) # Adjust batch size for the last batch + scheduler.set_timesteps(config.num_inference_steps) # Reset scheduler for each batch + # Initialize noise for this batch x = torch.randn(batch_size, 4, 64, 64, generator=generator, device=config.device) uncond = torch.zeros_like(cond) if config.guidance_scale else None