fix: incomplete inference
When the number of images to run inference on is not divisible by the batch size, inference for the last batch failed. This is due to how PNDMScheduler is implemented: it keeps internal state across calls. The fix is to call the scheduler's `set_timesteps` method for each batch, which resets the scheduler state before that batch is processed.
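A minimal sketch of the failure mode, using a toy stateful scheduler (a hypothetical class, not the real diffusers `PNDMScheduler`, whose internal state is more involved) to illustrate why resetting per batch matters:

```python
# Toy illustration only: a scheduler whose step() advances internal
# state, loosely mimicking the stateful behavior of PNDMScheduler.
# Class and functions here are hypothetical, for demonstration.

class ToyScheduler:
    def __init__(self):
        self.timesteps = []
        self.counter = 0  # internal state advanced by every step()

    def set_timesteps(self, num_steps):
        self.timesteps = list(range(num_steps))
        self.counter = 0  # resetting this state is what the fix relies on

    def step(self):
        if self.counter >= len(self.timesteps):
            raise RuntimeError("scheduler state exhausted")
        t = self.timesteps[self.counter]
        self.counter += 1
        return t


def run_batches(scheduler, num_batches, num_steps, reset_each_batch):
    for _ in range(num_batches):
        if reset_each_batch:
            # the fix: reset scheduler state before every batch
            scheduler.set_timesteps(num_steps)
        for _ in range(num_steps):
            scheduler.step()


sched = ToyScheduler()
sched.set_timesteps(5)  # set once up front, as in the old code
try:
    run_batches(sched, num_batches=2, num_steps=5, reset_each_batch=False)
except RuntimeError as e:
    print("without per-batch reset:", e)  # later batches hit stale state

run_batches(ToyScheduler(), num_batches=2, num_steps=5, reset_each_batch=True)
print("with per-batch reset: ok")
```

The toy simplifies the real failure (which involves PNDMScheduler's cached state interacting with a smaller final batch), but shows the same pattern: state set once before the loop goes stale, while a per-batch `set_timesteps` call keeps every batch starting clean.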
rizavelioglu committed Jan 8, 2025
1 parent 9a0fe4b commit c385ea2
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tryoffdiff/modeling/predict.py
@@ -35,7 +35,6 @@ def inference_tryoffdiff(config: InferenceConfig):

# Set up scheduler
scheduler = PNDMScheduler.from_pretrained(config.scheduler_dir)
-    scheduler.set_timesteps(config.num_inference_steps)

# Prepare dataloader
val_set = VitonVAESigLIPDataset(root_dir=config.val_img_dir, inference=True)
@@ -61,6 +60,8 @@ def inference_tryoffdiff(config: InferenceConfig):
for cond, img_name in tqdm(val_loader, desc="Processing batches"):
cond = cond.to(config.device)
batch_size = cond.size(0) # Adjust batch size for the last batch
+    scheduler.set_timesteps(config.num_inference_steps)  # Reset scheduler state for each batch

# Initialize noise for this batch
x = torch.randn(batch_size, 4, 64, 64, generator=generator, device=config.device)
uncond = torch.zeros_like(cond) if config.guidance_scale else None
