Skip to content

Commit

Permalink
Speed up mask sampling with rejection sampling (#2585)
Browse files Browse the repository at this point in the history
* change masked pixel sampling to use rejection sampling instead of torch.nonzero

* black reformat code

* pyright unbound variable num_valid

* pyright type issues with num_valid

* add configuration settings for rejection sampling masks

* black reformat

* maybe this fixes it?

* revert behavior if mask sampling failed, still raise warning

* on iteration failure, use non-rejection sampling to generate indices

* ruff

---------

Co-authored-by: adrian_chang <[email protected]>
Co-authored-by: Alexander Kristoffersen <[email protected]>
  • Loading branch information
3 people authored Jan 18, 2024
1 parent 15e81d3 commit a78ca29
Showing 1 changed file with 42 additions and 8 deletions.
50 changes: 42 additions & 8 deletions nerfstudio/data/pixel_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import random
import warnings
from dataclasses import dataclass, field
from typing import Dict, Optional, Type, Union

Expand All @@ -42,6 +43,10 @@ class PixelSamplerConfig(InstantiateConfig):
"""List of whether or not camera i is equirectangular."""
fisheye_crop_radius: Optional[float] = None
"""Set to the radius (in pixels) for fisheye cameras."""
rejection_sample_mask: bool = True
"""Whether or not to use rejection sampling when sampling images with masks"""
max_num_iterations: int = 100
"""If rejection sampling masks, the maximum number of times to sample"""


class PixelSampler:
Expand Down Expand Up @@ -88,15 +93,44 @@ def sample_method(
num_images: number of images to sample over
mask: mask of possible pixels in an image to sample from.
"""
indices = (
torch.rand((batch_size, 3), device=device)
* torch.tensor([num_images, image_height, image_width], device=device)
).long()

if isinstance(mask, torch.Tensor):
nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size)
indices = nonzero_indices[chosen_indices]
else:
indices = (
torch.rand((batch_size, 3), device=device)
* torch.tensor([num_images, image_height, image_width], device=device)
).long()
if self.config.rejection_sample_mask:
num_valid = 0
for _ in range(self.config.max_num_iterations):
c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1))
chosen_indices_validity = mask[..., 0][c, y, x].bool()
num_valid = int(torch.sum(chosen_indices_validity).item())
if num_valid == batch_size:
break
else:
replacement_indices = (
torch.rand((batch_size - num_valid, 3), device=device)
* torch.tensor([num_images, image_height, image_width], device=device)
).long()
indices[~chosen_indices_validity] = replacement_indices

if num_valid != batch_size:
warnings.warn(
"""
Masked sampling failed, mask is either empty or mostly empty.
Reverting behavior to non-rejection sampling. Consider setting
pipeline.datamanager.pixel-sampler.rejection-sample-mask to False
or increasing pipeline.datamanager.pixel-sampler.max-num-iterations
"""
)
self.config.rejection_sample_mask = False
nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size)
indices = nonzero_indices[chosen_indices]
else:
nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size)
indices = nonzero_indices[chosen_indices]

return indices

Expand Down

0 comments on commit a78ca29

Please sign in to comment.