diff --git a/nerfstudio/data/pixel_samplers.py b/nerfstudio/data/pixel_samplers.py index 0fc58bde01..144e405a57 100644 --- a/nerfstudio/data/pixel_samplers.py +++ b/nerfstudio/data/pixel_samplers.py @@ -17,6 +17,7 @@ """ import random +import warnings from dataclasses import dataclass, field from typing import Dict, Optional, Type, Union @@ -42,6 +43,10 @@ class PixelSamplerConfig(InstantiateConfig): """List of whether or not camera i is equirectangular.""" fisheye_crop_radius: Optional[float] = None """Set to the radius (in pixels) for fisheye cameras.""" + rejection_sample_mask: bool = True + """Whether or not to use rejection sampling when sampling images with masks""" + max_num_iterations: int = 100 + """If rejection sampling masks, the maximum number of times to sample""" class PixelSampler: @@ -88,15 +93,44 @@ def sample_method( num_images: number of images to sample over mask: mask of possible pixels in an image to sample from. """ + indices = ( + torch.rand((batch_size, 3), device=device) + * torch.tensor([num_images, image_height, image_width], device=device) + ).long() + if isinstance(mask, torch.Tensor): - nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False) - chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size) - indices = nonzero_indices[chosen_indices] - else: - indices = ( - torch.rand((batch_size, 3), device=device) - * torch.tensor([num_images, image_height, image_width], device=device) - ).long() + if self.config.rejection_sample_mask: + num_valid = 0 + for _ in range(self.config.max_num_iterations): + c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1)) + chosen_indices_validity = mask[..., 0][c, y, x].bool() + num_valid = int(torch.sum(chosen_indices_validity).item()) + if num_valid == batch_size: + break + else: + replacement_indices = ( + torch.rand((batch_size - num_valid, 3), device=device) + * torch.tensor([num_images, image_height, image_width], device=device) + ).long() + indices[~chosen_indices_validity] = replacement_indices + + if num_valid != batch_size: + warnings.warn( + """ + Masked sampling failed, mask is either empty or mostly empty. + Reverting behavior to non-rejection sampling. Consider setting + pipeline.datamanager.pixel-sampler.rejection-sample-mask to False + or increasing pipeline.datamanager.pixel-sampler.max-num-iterations + """ + ) + self.config.rejection_sample_mask = False + nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False) + chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size) + indices = nonzero_indices[chosen_indices] + else: + nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False) + chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size) + indices = nonzero_indices[chosen_indices] return indices