
Commit c385ea2

fix: incomplete inference
When the number of images to run inference on is not divisible by the batch size, inference for the last batch failed. This is because of how PNDMScheduler is implemented: it carries internal state across step calls. The fix calls the scheduler's set_timesteps method for each batch, which resets the scheduler state before each batch.
1 parent 9a0fe4b commit c385ea2

File tree

1 file changed: +2 −1 lines changed


tryoffdiff/modeling/predict.py

Lines changed: 2 additions & 1 deletion
@@ -35,7 +35,6 @@ def inference_tryoffdiff(config: InferenceConfig):
 
     # Set up scheduler
     scheduler = PNDMScheduler.from_pretrained(config.scheduler_dir)
-    scheduler.set_timesteps(config.num_inference_steps)
 
     # Prepare dataloader
     val_set = VitonVAESigLIPDataset(root_dir=config.val_img_dir, inference=True)
@@ -61,6 +60,8 @@ def inference_tryoffdiff(config: InferenceConfig):
     for cond, img_name in tqdm(val_loader, desc="Processing batches"):
         cond = cond.to(config.device)
         batch_size = cond.size(0)  # Adjust batch size for the last batch
+        scheduler.set_timesteps(config.num_inference_steps)  # Reset scheduler for each batch
+
         # Initialize noise for this batch
         x = torch.randn(batch_size, 4, 64, 64, generator=generator, device=config.device)
         uncond = torch.zeros_like(cond) if config.guidance_scale else None
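The failure mode can be illustrated with a toy, self-contained sketch (this is NOT the diffusers PNDMScheduler; the class and its fields are invented for illustration). Like PNDM, the toy scheduler caches previous model outputs across step calls, so a smaller final batch collides with stale cached tensors of the old batch shape unless set_timesteps is called to reset the state:

```python
import numpy as np

class ToyScheduler:
    """Toy stand-in (not the real PNDMScheduler) that, like PNDM,
    keeps a cache of past model outputs across step() calls."""
    def __init__(self):
        self.ets = []  # cache of past model outputs

    def set_timesteps(self, num_steps):
        self.num_steps = num_steps
        self.ets = []  # resetting clears the stale cache

    def step(self, model_output, sample):
        self.ets.append(model_output)
        if len(self.ets) > 1:
            # multistep update mixes current and cached outputs;
            # shapes must match, or NumPy raises a broadcasting error
            blended = 0.5 * (self.ets[-1] + self.ets[-2])
        else:
            blended = model_output
        return sample - 0.1 * blended

sched = ToyScheduler()

def run_batch(batch_size, reset):
    if reset:
        sched.set_timesteps(4)
    x = np.zeros((batch_size, 4))
    for _ in range(4):
        x = sched.step(np.ones_like(x), x)
    return x

sched.set_timesteps(4)
run_batch(8, reset=False)        # full batch works
try:
    run_batch(5, reset=False)    # smaller last batch hits stale (8, 4) cache
    failed = False
except ValueError:
    failed = True
print("without per-batch reset, last batch failed:", failed)
out = run_batch(5, reset=True)   # per-batch reset clears the cache
print("with per-batch reset, last batch shape:", out.shape)
```

With the real scheduler the symptom is analogous: the cached outputs from the previous full-size batch no longer match the final, smaller batch, which is why the fix moves set_timesteps inside the loop.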

0 commit comments
