diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index 84000e73c5..b23eabe002 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -779,8 +779,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: W, background=torch.zeros(3, device=self.device), )[..., 0:1] # type: ignore - depth_im[alpha > 0] = depth_im[alpha > 0] / alpha[alpha > 0] - depth_im[alpha == 0] = 1000 + depth_im = torch.where(alpha > 0, depth_im / alpha, depth_im.detach().max()) return {"rgb": rgb, "depth": depth_im, "accumulation": alpha} # type: ignore