Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
liruilong940607 committed Jan 9, 2024
1 parent cc98282 commit 8642e0c
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 51 deletions.
105 changes: 67 additions & 38 deletions gsplat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,29 @@
import torch
from .project_gaussians import project_gaussians
from .rasterize import rasterize_gaussians
from .utils import map_gaussian_to_intersects, bin_and_sort_gaussians, compute_cumulative_intersects, compute_cov2d_bounds, get_tile_bin_edges
from .utils import (
map_gaussian_to_intersects,
bin_and_sort_gaussians,
compute_cumulative_intersects,
compute_cov2d_bounds,
get_tile_bin_edges,
)
from .sh import spherical_harmonics
from .version import __version__
import warnings


__all__ = [
"__version__",

"project_gaussians",
"rasterize_gaussians",
"spherical_harmonics",

# utils
"bin_and_sort_gaussians",
"compute_cumulative_intersects",
"compute_cov2d_bounds",
"get_tile_bin_edges",
"map_gaussian_to_intersects",

# Function.apply() will be deprecated
"ProjectGaussians",
"RasterizeGaussians",
Expand All @@ -36,102 +39,128 @@

# Define these for backwards compatibility

class MapGaussiansToIntersects(torch.autograd.Function):

class MapGaussiansToIntersects(torch.autograd.Function):
@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("MapGaussiansToIntersects is deprecated, use map_gaussian_to_intersects instead", DeprecationWarning)
warnings.warn(
"MapGaussiansToIntersects is deprecated, use map_gaussian_to_intersects instead",
DeprecationWarning,
)
return map_gaussian_to_intersects(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError



class ComputeCumulativeIntersects(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("ComputeCumulativeIntersects is deprecated, use compute_cumulative_intersects instead", DeprecationWarning)
warnings.warn(
"ComputeCumulativeIntersects is deprecated, use compute_cumulative_intersects instead",
DeprecationWarning,
)
return compute_cumulative_intersects(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError



class ComputeCov2dBounds(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("ComputeCov2dBounds is deprecated, use compute_cov2d_bounds instead", DeprecationWarning)
warnings.warn(
"ComputeCov2dBounds is deprecated, use compute_cov2d_bounds instead",
DeprecationWarning,
)
return compute_cov2d_bounds(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError


class GetTileBinEdges(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("GetTileBinEdges is deprecated, use get_tile_bin_edges instead", DeprecationWarning)
warnings.warn(
"GetTileBinEdges is deprecated, use get_tile_bin_edges instead",
DeprecationWarning,
)
return get_tile_bin_edges(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError



class BinAndSortGaussians(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("BinAndSortGaussians is deprecated, use bin_and_sort_gaussians instead", DeprecationWarning)
warnings.warn(
"BinAndSortGaussians is deprecated, use bin_and_sort_gaussians instead",
DeprecationWarning,
)
return bin_and_sort_gaussians(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError



class ProjectGaussians(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("ProjectGaussians is deprecated, use project_gaussians instead", DeprecationWarning)
warnings.warn(
"ProjectGaussians is deprecated, use project_gaussians instead",
DeprecationWarning,
)
return project_gaussians(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError



class RasterizeGaussians(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("RasterizeGaussians is deprecated, use rasterize_gaussians instead", DeprecationWarning)
warnings.warn(
"RasterizeGaussians is deprecated, use rasterize_gaussians instead",
DeprecationWarning,
)
return rasterize_gaussians(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError



class NDRasterizeGaussians(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("NDRasterizeGaussians is deprecated, use rasterize_gaussians instead", DeprecationWarning)
warnings.warn(
"NDRasterizeGaussians is deprecated, use rasterize_gaussians instead",
DeprecationWarning,
)
return rasterize_gaussians(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

class SphericalHarmonics(torch.autograd.Function):

class SphericalHarmonics(torch.autograd.Function):
@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("SphericalHarmonics is deprecated, use spherical_harmonics instead", DeprecationWarning)
warnings.warn(
"SphericalHarmonics is deprecated, use spherical_harmonics instead",
DeprecationWarning,
)
return spherical_harmonics(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

5 changes: 3 additions & 2 deletions gsplat/project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import gsplat.cuda as _C


def project_gaussians(
means3d: Float[Tensor, "*batch 3"],
scales: Float[Tensor, "*batch 3"],
Expand All @@ -28,7 +29,7 @@ def project_gaussians(
Note:
This function is differentiable w.r.t the means3d, scales and quats inputs.
Args:
means3d (Tensor): xyzs of gaussians.
scales (Tensor): scales of the gaussians.
Expand Down Expand Up @@ -210,4 +211,4 @@ def backward(ctx, v_xys, v_depths, v_radii, v_conics, v_num_tiles_hit, v_cov3d):
None,
# clip_thresh,
None,
)
)
9 changes: 4 additions & 5 deletions gsplat/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def rasterize_gaussians(
background.shape[0] == colors.shape[-1]
), f"incorrect shape of background color tensor, expected shape {colors.shape[-1]}"
else:
background = torch.ones(colors.shape[-1], dtype=torch.float32, device=colors.device)
background = torch.ones(
colors.shape[-1], dtype=torch.float32, device=colors.device
)

if xys.ndimension() != 2 or xys.size(1) != 2:
raise ValueError("xys must have dimensions (N, 2)")
Expand Down Expand Up @@ -103,9 +105,7 @@ def forward(
block = (BLOCK_X, BLOCK_Y, 1)
img_size = (img_width, img_height, 1)

num_intersects, cum_tiles_hit = compute_cumulative_intersects(
num_tiles_hit
)
num_intersects, cum_tiles_hit = compute_cumulative_intersects(num_tiles_hit)

(
isect_ids_unsorted,
Expand Down Expand Up @@ -198,4 +198,3 @@ def backward(ctx, v_out_img):
None, # img_width
None, # background
)

9 changes: 6 additions & 3 deletions gsplat/sh.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def deg_from_sh(num_bases: int):
return 4
assert False, "Invalid number of SH bases"


def spherical_harmonics(
degrees_to_use: int,
viewdirs: Float[Tensor, "*batch 3"],
Expand All @@ -46,12 +47,15 @@ def spherical_harmonics(
degrees_to_use (int): degree of SHs to use (<= total number available).
viewdirs (Tensor): viewing directions.
coeffs (Tensor): harmonic coefficients.
Returns:
The spherical harmonics.
"""
assert coeffs.shape[-2] >= num_sh_bases(degrees_to_use)
return _SphericalHarmonics.apply(degrees_to_use, viewdirs.contiguous(), coeffs.contiguous())
return _SphericalHarmonics.apply(
degrees_to_use, viewdirs.contiguous(), coeffs.contiguous()
)


class _SphericalHarmonics(Function):
"""Compute spherical harmonics
Expand Down Expand Up @@ -91,4 +95,3 @@ def backward(ctx, v_colors: Float[Tensor, "*batch 3"]):
num_points, degree, degrees_to_use, viewdirs, v_colors
),
)

13 changes: 10 additions & 3 deletions gsplat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,17 @@ def map_gaussian_to_intersects(
- **gaussian_ids** (Tensor): Tensor that maps isect_ids back to cum_tiles_hit.
"""
isect_ids, gaussian_ids = _C.map_gaussian_to_intersects(
num_points, num_intersects, xys.contiguous(), depths.contiguous(), radii.contiguous(), cum_tiles_hit.contiguous(), tile_bounds
num_points,
num_intersects,
xys.contiguous(),
depths.contiguous(),
radii.contiguous(),
cum_tiles_hit.contiguous(),
tile_bounds,
)
return (isect_ids, gaussian_ids)


def get_tile_bin_edges(
num_intersects: int, isect_ids_sorted: Int[Tensor, "num_intersects 1"]
) -> Int[Tensor, "num_intersects 2"]:
Expand Down Expand Up @@ -96,7 +103,7 @@ def compute_cumulative_intersects(
Note:
This function is not differentiable to any input.
Args:
num_tiles_hit (Tensor): number of intersected tiles per gaussian.
Expand Down Expand Up @@ -128,7 +135,7 @@ def bin_and_sort_gaussians(
]:
"""Mapping gaussians to sorted unique intersection IDs and tile bins used for fast rasterization.
We return both sorted and unsorted versions of intersect IDs and gaussian IDs for testing purposes.
We return both sorted and unsorted versions of intersect IDs and gaussian IDs for testing purposes.
Note:
This function is not differentiable to any input.
Expand Down

0 comments on commit 8642e0c

Please sign in to comment.