Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
liruilong940607 committed Jan 9, 2024
1 parent 0527c6e commit 857f407
Show file tree
Hide file tree
Showing 17 changed files with 554 additions and 615 deletions.
4 changes: 2 additions & 2 deletions docs/source/apis/rast.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ And
:math:`σ ∈ R^{2}` is the Mahalanobis distance (here referred to as sigma) which measures how many standard deviations away the center of a gaussian and the rendered pixel center is which is denoted by delta :math:`∆.`

The python bindings support conventional 3-channel RGB rasterization with :func:`gsplat.rasterize_gaussians` as well as N-dimensional rasterization with :func:`gsplat.ndrasterize_gaussians`.
The python bindings support conventional 3-channel RGB rasterization with :func:`gsplat.rasterize_gaussians` as well as N-dimensional rasterization with :func:`gsplat.nd_rasterize_gaussians`.


.. autofunction:: rasterize_gaussians

.. autofunction:: ndrasterize_gaussians
.. autofunction:: nd_rasterize_gaussians
2 changes: 1 addition & 1 deletion docs/source/apis/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ In addition to the main projection and rasterization functions, a few CUDA kerne

.. currentmodule:: gsplat

.. autofunction:: bin_an_sort_gaussians
.. autofunction:: bin_and_sort_gaussians

.. autofunction:: compute_cov2d_bounds

Expand Down
4 changes: 2 additions & 2 deletions examples/test_rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import tyro
from gsplat.project_gaussians import project_gaussians
from gsplat.rasterize import rasterize_gaussians
from gsplat.nd_rasterize import ndrasterize_gaussians
from gsplat.nd_rasterize import nd_rasterize_gaussians
from PIL import Image
from torch import Tensor, optim

Expand Down Expand Up @@ -138,7 +138,7 @@ def forward_slow(self):
self.tile_bounds,
)

return ndrasterize_gaussians(
return nd_rasterize_gaussians(
xys,
depths,
radii,
Expand Down
147 changes: 127 additions & 20 deletions gsplat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,139 @@
from .project_gaussians import ProjectGaussians, project_gaussians
from .rasterize import RasterizeGaussians, rasterize_gaussians
from .bin_and_sort_gaussians import BinAndSortGaussians, bin_and_sort_gaussians
from .compute_cumulative_intersects import ComputeCumulativeIntersects, compute_cumulative_intersects
from .cov2d_bounds import ComputeCov2dBounds, compute_cov2d_bounds
from .get_tile_bin_edges import GetTileBinEdges, get_tile_bin_edges
from .map_gaussian_to_intersects import MapGaussiansToIntersects, map_gaussian_to_intersects
from .sh import SphericalHarmonics, spherical_harmonics
from .nd_rasterize import NDRasterizeGaussians, ndrasterize_gaussians
from typing import Any
import torch
from .project_gaussians import project_gaussians
from .rasterize import rasterize_gaussians
from .utils import map_gaussian_to_intersects, bin_and_sort_gaussians, compute_cumulative_intersects, compute_cov2d_bounds, get_tile_bin_edges
from .sh import spherical_harmonics
from .nd_rasterize import nd_rasterize_gaussians
from .version import __version__
import warnings


__all__ = [
"__version__",
"ProjectGaussians", # deprecated

"project_gaussians",
"RasterizeGaussians", # deprecated
"rasterize_gaussians",
"BinAndSortGaussians", # deprecated
"nd_rasterize_gaussians",
"spherical_harmonics",

# utils
"bin_and_sort_gaussians",
"ComputeCumulativeIntersects", # deprecated
"compute_cumulative_intersects",
"ComputeCov2dBounds", # deprecated
"compute_cov2d_bounds",
"GetTileBinEdges", # deprecated
"get_tile_bin_edges",
"MapGaussiansToIntersects", # deprecated
"map_gaussian_to_intersects",
"SphericalHarmonics", # deprecated
"spherical_harmonics",
"NDRasterizeGaussians", # deprecated
"ndrasterize_gaussians",

# Function.apply() will be deprecated
"ProjectGaussians",
"RasterizeGaussians",
"BinAndSortGaussians",
"ComputeCumulativeIntersects",
"ComputeCov2dBounds",
"GetTileBinEdges",
"MapGaussiansToIntersects",
"SphericalHarmonics",
"NDRasterizeGaussians",
]

# Define these for backwards compatibility

class MapGaussiansToIntersects(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("MapGaussiansToIntersects is deprecated, use map_gaussian_to_intersects instead", DeprecationWarning)
return map_gaussian_to_intersects(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

class ComputeCumulativeIntersects(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("ComputeCumulativeIntersects is deprecated, use compute_cumulative_intersects instead", DeprecationWarning)
return compute_cumulative_intersects(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

class ComputeCov2dBounds(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("ComputeCov2dBounds is deprecated, use compute_cov2d_bounds instead", DeprecationWarning)
return compute_cov2d_bounds(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

class GetTileBinEdges(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("GetTileBinEdges is deprecated, use get_tile_bin_edges instead", DeprecationWarning)
return get_tile_bin_edges(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

class BinAndSortGaussians(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("BinAndSortGaussians is deprecated, use bin_and_sort_gaussians instead", DeprecationWarning)
return bin_and_sort_gaussians(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

class ProjectGaussians(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("ProjectGaussians is deprecated, use project_gaussians instead", DeprecationWarning)
return project_gaussians(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

class RasterizeGaussians(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("RasterizeGaussians is deprecated, use rasterize_gaussians instead", DeprecationWarning)
return rasterize_gaussians(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

class NDRasterizeGaussians(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("NDRasterizeGaussians is deprecated, use nd_rasterize_gaussians instead", DeprecationWarning)
return nd_rasterize_gaussians(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

class SphericalHarmonics(torch.autograd.Function):

@staticmethod
def forward(ctx, *args, **kwargs):
warnings.warn("SphericalHarmonics is deprecated, use spherical_harmonics instead", DeprecationWarning)
return spherical_harmonics(*args, **kwargs)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

99 changes: 0 additions & 99 deletions gsplat/bin_and_sort_gaussians.py

This file was deleted.

48 changes: 0 additions & 48 deletions gsplat/compute_cumulative_intersects.py

This file was deleted.

Loading

0 comments on commit 857f407

Please sign in to comment.