diff --git a/tryoffdiff/modeling/predict.py b/tryoffdiff/modeling/predict.py index 37d7a3b..90f5b0a 100644 --- a/tryoffdiff/modeling/predict.py +++ b/tryoffdiff/modeling/predict.py @@ -35,7 +35,6 @@ def inference_tryoffdiff(config: InferenceConfig): # Set up scheduler scheduler = PNDMScheduler.from_pretrained(config.scheduler_dir) - scheduler.set_timesteps(config.num_inference_steps) # Prepare dataloader val_set = VitonVAESigLIPDataset(root_dir=config.val_img_dir, inference=True) @@ -61,6 +60,8 @@ def inference_tryoffdiff(config: InferenceConfig): for cond, img_name in tqdm(val_loader, desc="Processing batches"): cond = cond.to(config.device) batch_size = cond.size(0) # Adjust batch size for the last batch + scheduler.set_timesteps(config.num_inference_steps) # Reset scheduler for each batch + # Initialize noise for this batch x = torch.randn(batch_size, 4, 64, 64, generator=generator, device=config.device) uncond = torch.zeros_like(cond) if config.guidance_scale else None