Skip to content

Commit

Permalink
Migrate to python functions for gsplat (#2778)
Browse files Browse the repository at this point in the history
* switch to pythonic interface for gsplat
  • Loading branch information
kerrj authored Jan 18, 2024
1 parent a78ca29 commit b76a240
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 20 deletions.
36 changes: 17 additions & 19 deletions nerfstudio/models/gaussian_splatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@
import numpy as np
import torch
from gsplat._torch_impl import quat_to_rotmat
from gsplat.compute_cumulative_intersects import compute_cumulative_intersects
from gsplat.project_gaussians import ProjectGaussians
from gsplat.rasterize import RasterizeGaussians
from gsplat.sh import SphericalHarmonics, num_sh_bases
from gsplat.project_gaussians import project_gaussians
from gsplat.rasterize import rasterize_gaussians
from gsplat.sh import num_sh_bases, spherical_harmonics
from pytorch_msssim import SSIM
from torch.nn import Parameter

Expand Down Expand Up @@ -324,7 +323,8 @@ def after_train(self, step: int):
with torch.no_grad():
# keep track of a moving average of grad norms
visible_mask = (self.radii > 0).flatten()
grads = self.xys.grad.detach().norm(dim=-1) # TODO fill in
assert self.xys.grad is not None
grads = self.xys.grad.detach().norm(dim=-1)
# print(f"grad norm min {grads.min().item()} max {grads.max().item()} mean {grads.mean().item()} size {grads.shape}")
if self.xys_grad_norm is None:
self.xys_grad_norm = grads
Expand Down Expand Up @@ -629,13 +629,13 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
cy = camera.cy.item()
fovx = 2 * math.atan(camera.width / (2 * camera.fx))
fovy = 2 * math.atan(camera.height / (2 * camera.fy))
W, H = camera.width.item(), camera.height.item()
W, H = int(camera.width.item()), int(camera.height.item())
self.last_size = (H, W)
projmat = projection_matrix(0.001, 1000, fovx, fovy, device=self.device)
BLOCK_X, BLOCK_Y = 16, 16
tile_bounds = (
(W + BLOCK_X - 1) // BLOCK_X,
(H + BLOCK_Y - 1) // BLOCK_Y,
int((W + BLOCK_X - 1) // BLOCK_X),
int((H + BLOCK_Y - 1) // BLOCK_Y),
1,
)

Expand All @@ -656,7 +656,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:

colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)

self.xys, depths, self.radii, conics, num_tiles_hit, cov3d = ProjectGaussians.apply(
self.xys, depths, self.radii, conics, num_tiles_hit, cov3d = project_gaussians( # type: ignore
means_crop,
torch.exp(scales_crop),
1,
Expand All @@ -682,44 +682,42 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
viewdirs = means_crop.detach() - camera.camera_to_worlds.detach()[..., :3, 3] # (N, 3)
viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
rgbs = SphericalHarmonics.apply(n, viewdirs, colors_crop)
rgbs = spherical_harmonics(n, viewdirs, colors_crop)
rgbs = torch.clamp(rgbs + 0.5, min=0.0) # type: ignore
else:
rgbs = torch.sigmoid(colors_crop[:, 0, :])

# rescale the camera back to original dimensions
camera.rescale_output_resolution(camera_downscale)

# avoid empty rasterization
num_intersects, _ = compute_cumulative_intersects(self.xys.size(0), num_tiles_hit)
assert num_intersects > 0
assert (num_tiles_hit > 0).any() # type: ignore

rgb = RasterizeGaussians.apply(
rgb = rasterize_gaussians( # type: ignore
self.xys,
depths,
self.radii,
conics,
num_tiles_hit,
num_tiles_hit, # type: ignore
rgbs,
torch.sigmoid(opacities_crop),
H,
W,
background,
background=background,
) # type: ignore
rgb = torch.clamp(rgb, max=1.0) # type: ignore
depth_im = None
if not self.training:
depth_im = RasterizeGaussians.apply( # type: ignore
depth_im = rasterize_gaussians( # type: ignore
self.xys,
depths,
self.radii,
conics,
num_tiles_hit,
num_tiles_hit, # type: ignore
depths[:, None].repeat(1, 3),
torch.sigmoid(opacities_crop),
H,
W,
torch.ones(3, device=self.device) * 10,
background=torch.ones(3, device=self.device) * 10,
)[..., 0:1] # type: ignore

return {"rgb": rgb, "depth": depth_im} # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ dependencies = [
"xatlas",
"trimesh>=3.20.2",
"timm==0.6.7",
"gsplat==0.1.0",
"gsplat==0.1.2.1",
"pytorch-msssim",
"pathos"
]
Expand Down

0 comments on commit b76a240

Please sign in to comment.