Skip to content

Commit

Permalink
add return_alpha keyword
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhuoyang-Pan committed Jan 14, 2024
1 parent f72dd2a commit 895ea08
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
Empty file removed gsplat/nd_rasterize.py
Empty file.
15 changes: 11 additions & 4 deletions gsplat/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def rasterize_gaussians(
img_height: int,
img_width: int,
background: Optional[Float[Tensor, "channels"]] = None,
return_alpha: Optional[bool] = False,
) -> Tensor:
"""Rasterizes 2D gaussians by sorting and binning gaussian intersections for each tile and returns an N-dimensional output using alpha-compositing.
Expand Down Expand Up @@ -75,6 +76,7 @@ def rasterize_gaussians(
img_height,
img_width,
background.contiguous(),
return_alpha,
)


Expand All @@ -94,6 +96,7 @@ def forward(
img_height: int,
img_width: int,
background: Optional[Float[Tensor, "channels"]] = None,
return_alpha: Optional[bool] = False,
) -> Tensor:
num_points = xys.size(0)
BLOCK_X, BLOCK_Y = 16, 16
Expand Down Expand Up @@ -147,9 +150,12 @@ def forward(
final_Ts,
final_idx,
)
out_alpha = 1 - final_Ts

return out_img, out_alpha

if return_alpha:
out_alpha = 1 - final_Ts
return out_img, out_alpha
else:
return out_img

@staticmethod
def backward(ctx, v_out_img, v_out_alpha=None):
Expand All @@ -158,7 +164,7 @@ def backward(ctx, v_out_img, v_out_alpha=None):

if v_out_alpha is None:
v_out_alpha = torch.zeros_like(v_out_img[..., 0])

print(v_out_alpha)
(
gaussian_ids_sorted,
tile_bins,
Expand Down Expand Up @@ -202,4 +208,5 @@ def backward(ctx, v_out_img, v_out_alpha=None):
None, # img_height
None, # img_width
None, # background
None, # return_alpha
)

0 comments on commit 895ea08

Please sign in to comment.